diff --git a/src/marketdata/exceptions.py b/src/marketdata/exceptions.py index 1436aba..f44f472 100644 --- a/src/marketdata/exceptions.py +++ b/src/marketdata/exceptions.py @@ -94,5 +94,13 @@ class InvalidStatusDataError(BaseMarketdataException): pass -class MinMaxDateValidationError(BaseMarketdataException): +class MinMaxValidationError(BaseMarketdataException): + pass + + +class MinMaxValueValidationError(MinMaxValidationError): + pass + + +class MinMaxDateValidationError(MinMaxValidationError): pass diff --git a/src/marketdata/input_types/base.py b/src/marketdata/input_types/base.py index c83fde4..259462e 100644 --- a/src/marketdata/input_types/base.py +++ b/src/marketdata/input_types/base.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator -from marketdata.exceptions import MinMaxDateValidationError +from marketdata.exceptions import MinMaxDateValidationError, MinMaxValueValidationError from marketdata.utils import check_is_date BaseModelConfig = ConfigDict(populate_by_name=True, frozen=False) @@ -26,6 +26,17 @@ def _validate_min_max_dates( f"{min_param} must be less than {max_param}" ) + def _validate_min_max_value( + self, min_param: str | None, max_param: str | None + ) -> None: + min_value = getattr(self, min_param) + max_value = getattr(self, max_param) + + if min_value is not None and max_value is not None and min_value > max_value: + raise MinMaxValueValidationError( + f"{min_param} must be less than or equal to {max_param}" + ) + class OutputFormat(str, Enum): DATAFRAME = "dataframe" diff --git a/src/marketdata/input_types/options.py b/src/marketdata/input_types/options.py index b15dba9..c600ed0 100644 --- a/src/marketdata/input_types/options.py +++ b/src/marketdata/input_types/options.py @@ -125,12 +125,12 @@ def validate_expiration( @model_validator(mode="after") def validate_input(self) -> "OptionsChainInput": - params_typles = [ + params_tuples = [ ("min_bid", "max_bid"), ("min_ask", "max_ask"), ] - for min_param, max_param in params_typles: - self._validate_min_max_dates(min_param, max_param) + for min_param, max_param in params_tuples: + self._validate_min_max_value(min_param, max_param) return self diff --git a/src/tests/test_input_types.py b/src/tests/test_input_types.py index 5ce292c..63727b6 100644 --- a/src/tests/test_input_types.py +++ b/src/tests/test_input_types.py @@ -7,12 +7,17 @@ from pydantic import Field, model_validator import marketdata.input_types as input_types_pkg -from marketdata.exceptions import MinMaxDateValidationError +from marketdata.exceptions import ( + MinMaxDateValidationError, + MinMaxValidationError, + MinMaxValueValidationError, +) from marketdata.input_types.base import ( BaseInputType, OutputFormat, UserUniversalAPIParams, ) +from marketdata.input_types.options import OptionsChainInput from marketdata.internal_settings import GLOBAL_EXCLUDED_PARAMS @@ -26,6 +31,16 @@ def validate_input(self) -> "DummyInput": return self +class DummyNumericInput(BaseInputType): + min_param: float | None = Field(default=None) + max_param: float | None = Field(default=None) + + @model_validator(mode="after") + def validate_input(self) -> "DummyNumericInput": + self._validate_min_max_value("min_param", "max_param") + return self + + def _all_input_models() -> list[type[BaseInputType]]: """Return every concrete BaseInputType subclass defined in the SDK. @@ -95,6 +110,52 @@ def test_base_input_type_min_max_validation(): DummyInput(min_param="2025-01-01", max_param="2024-01-01") +def test_base_input_type_min_max_value_validation(): + with pytest.raises(MinMaxValueValidationError): + DummyNumericInput(min_param=5.0, max_param=1.0) + + +def test_base_input_type_min_max_value_valid_range(): + instance = DummyNumericInput(min_param=1.0, max_param=5.0) + assert instance.min_param == 1.0 + assert instance.max_param == 5.0 + + +def test_base_input_type_min_max_value_allows_none(): + # Either bound missing -> no comparison, no error. + assert DummyNumericInput(min_param=5.0).max_param is None + assert DummyNumericInput(max_param=1.0).min_param is None + assert DummyNumericInput().min_param is None + + +def test_min_max_errors_share_common_base(): + # Both specialized errors must be catchable as the common MinMaxValidationError. + assert issubclass(MinMaxDateValidationError, MinMaxValidationError) + assert issubclass(MinMaxValueValidationError, MinMaxValidationError) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"min_bid": 5.0, "max_bid": 1.0}, + {"min_ask": 10.0, "max_ask": 2.0}, + ], +) +def test_options_chain_input_invalid_price_range(kwargs: dict): + with pytest.raises(MinMaxValueValidationError): + OptionsChainInput(symbol="AAPL", **kwargs) + + +def test_options_chain_input_valid_price_range(): + instance = OptionsChainInput( + symbol="AAPL", min_bid=1.0, max_bid=5.0, min_ask=2.0, max_ask=10.0 + ) + assert instance.min_bid == 1.0 + assert instance.max_bid == 5.0 + assert instance.min_ask == 2.0 + assert instance.max_ask == 10.0 + + def test_universal_api_params_api_format(): params = UserUniversalAPIParams(output_format=OutputFormat.DATAFRAME) assert params.api_format == OutputFormat.JSON