diff --git a/src/marketdata/input_types/options.py b/src/marketdata/input_types/options.py index 7da03ee..b15dba9 100644 --- a/src/marketdata/input_types/options.py +++ b/src/marketdata/input_types/options.py @@ -35,7 +35,9 @@ class OptionsChainInput(BaseInputType): description="The expiration date to filter by", default=None ) days_to_expiration: int | None = Field( - description="The number of days to expiration to filter by", default=None + description="The number of days to expiration to filter by", + alias="dte", + default=None, ) from_date: datetime.date | str | None = Field( description="The start date to fetch options chain for", diff --git a/src/tests/test_input_types.py b/src/tests/test_input_types.py index 6c5d317..5ce292c 100644 --- a/src/tests/test_input_types.py +++ b/src/tests/test_input_types.py @@ -1,15 +1,19 @@ import datetime +import importlib +import pkgutil from pathlib import Path import pytest from pydantic import Field, model_validator +import marketdata.input_types as input_types_pkg from marketdata.exceptions import MinMaxDateValidationError from marketdata.input_types.base import ( BaseInputType, OutputFormat, UserUniversalAPIParams, ) +from marketdata.internal_settings import GLOBAL_EXCLUDED_PARAMS class DummyInput(BaseInputType): @@ -22,6 +26,70 @@ def validate_input(self) -> "DummyInput": return self +def _all_input_models() -> list[type[BaseInputType]]: + """Return every concrete BaseInputType subclass defined in the SDK. + + Imports each input_types submodule first so all subclasses are registered. + """ + for module_info in pkgutil.iter_modules(input_types_pkg.__path__): + importlib.import_module(f"{input_types_pkg.__name__}.{module_info.name}") + + seen: set[type[BaseInputType]] = set() + stack: list[type[BaseInputType]] = [BaseInputType] + while stack: + for sub in stack.pop().__subclasses__(): + if sub not in seen: + seen.add(sub) + stack.append(sub) + # Only audit models shipped in the SDK, not test-only helper subclasses. + return sorted( + (c for c in seen if c.__module__.startswith(input_types_pkg.__name__)), + key=lambda c: c.__name__, + ) + + +def _snake_case_fields() -> list[tuple[str, str, "object"]]: + """All (model_name, field_name, field) tuples for fields whose Python name + differs from a bare API parameter (i.e. contain an underscore).""" + cases = [] + for model in _all_input_models(): + for field_name, field in model.model_fields.items(): + if "_" in field_name: + cases.append((model.__name__, field_name, field)) + return cases + + +_SNAKE_CASE_FIELDS = _snake_case_fields() + + +def test_snake_case_field_discovery_is_not_empty(): + # Guard: if discovery silently breaks, the parametrized test below would + # vacuously pass. days_to_expiration alone guarantees at least one case. + assert _SNAKE_CASE_FIELDS + + +@pytest.mark.parametrize( + "model_name, field_name, field", + _SNAKE_CASE_FIELDS, + ids=[f"{m}.{f}" for m, f, _ in _SNAKE_CASE_FIELDS], +) +def test_snake_case_input_fields_have_api_alias(model_name, field_name, field): + """Every multi-word input field must be sent under an explicit API alias. + + The URL builder serializes with ``by_alias=True``; a snake_case field + without an alias would leak its Python name to the wire (see issue #30, + ``days_to_expiration`` -> ``dte``). Fields that are never serialized to the + query string are exempted via GLOBAL_EXCLUDED_PARAMS. + """ + if field_name in GLOBAL_EXCLUDED_PARAMS: + return + + assert field.alias and field.alias != field_name, ( + f"{model_name}.{field_name} has no API alias; it would be sent to the " + f"API as '{field_name}'. Add an alias or exclude it." + ) + + def test_base_input_type_min_max_validation(): with pytest.raises(MinMaxDateValidationError): DummyInput(min_param="2025-01-01", max_param="2024-01-01") diff --git a/src/tests/test_options_chain.py b/src/tests/test_options_chain.py index a8477e0..8a680d5 100644 --- a/src/tests/test_options_chain.py +++ b/src/tests/test_options_chain.py @@ -334,3 +334,22 @@ def test_options_chain_strike_limit_is_int_on_wire(load_json, respx_mock, client params = respx_mock.calls.last.request.url.params assert params.get("strikeLimit") == "10" + + +def test_options_chain_input_days_to_expiration_alias_on_wire( + load_json, respx_mock, client +): + mock_data = load_json("options_chain_response_200") + respx_mock.get("https://api.marketdata.app/v1/options/chain/AAPL/").respond( + json=mock_data, status_code=200 + ) + + client.options.chain( + "AAPL", + days_to_expiration=30, + output_format=OutputFormat.INTERNAL, + ) + + params = respx_mock.calls.last.request.url.params + assert params.get("dte") == "30" + assert params.get("days_to_expiration") is None