From ad8471368fd7454f7e0f99160afbe4de5e2c0c13 Mon Sep 17 00:00:00 2001 From: rafa-guedes Date: Wed, 18 Feb 2026 17:40:36 +1300 Subject: [PATCH 1/3] Add a generic pydantic base class that forbids extra arguments from being passed --- src/oceanum/_base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 src/oceanum/_base.py diff --git a/src/oceanum/_base.py b/src/oceanum/_base.py new file mode 100644 index 0000000..fbeb304 --- /dev/null +++ b/src/oceanum/_base.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +"""Base classes for oceanum models.""" + +from pydantic import BaseModel, ConfigDict + + +class StrictBaseModel(BaseModel): + """ + Base model with strict validation that forbids extra fields. + + This prevents silent failures when users make typos in field names. + Instead of ignoring unknown fields, a ValidationError is raised. + """ + + model_config = ConfigDict(extra="forbid") From 5c3a125975e70813b2e4971e1b7488a94b64abba Mon Sep 17 00:00:00 2001 From: rafa-guedes Date: Wed, 18 Feb 2026 17:41:26 +1300 Subject: [PATCH 2/3] Use the new strict base model across the codebase --- src/oceanum/cli/models.py | 11 ++++++----- src/oceanum/cli/renderer.py | 3 ++- src/oceanum/datamesh/datasource.py | 8 ++++---- src/oceanum/datamesh/query.py | 29 +++++++++++------------------ src/oceanum/datamesh/session.py | 4 ++-- 5 files changed, 25 insertions(+), 30 deletions(-) diff --git a/src/oceanum/cli/models.py b/src/oceanum/cli/models.py index 956d034..8a5b303 100644 --- a/src/oceanum/cli/models.py +++ b/src/oceanum/cli/models.py @@ -8,12 +8,13 @@ from typing_extensions import Self import json -from pydantic import BaseModel, Field +from pydantic import Field +from oceanum._base import StrictBaseModel from . 
import utils -class DeviceCodeResponse(BaseModel): +class DeviceCodeResponse(StrictBaseModel): device_code: str user_code: str verification_uri: str @@ -21,7 +22,7 @@ class DeviceCodeResponse(BaseModel): interval: int verification_uri_complete: str -class TokenResponse(BaseModel): +class TokenResponse(StrictBaseModel): access_token: str id_token: str|None = None refresh_token: str|None = None @@ -77,11 +78,11 @@ def delete(self) -> bool: return True return False -class Auth0Config(BaseModel): +class Auth0Config(StrictBaseModel): domain: str client_id: str -class ContextObject(BaseModel): +class ContextObject(StrictBaseModel): domain: str token: TokenResponse|None=None auth0: Auth0Config|None=None diff --git a/src/oceanum/cli/renderer.py b/src/oceanum/cli/renderer.py index eb548e3..c15aeed 100644 --- a/src/oceanum/cli/renderer.py +++ b/src/oceanum/cli/renderer.py @@ -17,6 +17,7 @@ from tabulate import tabulate from pydantic import BaseModel +from oceanum._base import StrictBaseModel _sty = click.style @@ -27,7 +28,7 @@ help='Output format' ) -class RenderField(BaseModel): +class RenderField(StrictBaseModel): default: Any|None = None label: str = 'Name' path: str = '$.name' diff --git a/src/oceanum/datamesh/datasource.py b/src/oceanum/datamesh/datasource.py index 9352bbe..348a916 100644 --- a/src/oceanum/datamesh/datasource.py +++ b/src/oceanum/datamesh/datasource.py @@ -11,7 +11,6 @@ import warnings from pydantic import ( ConfigDict, - BaseModel, Field, AnyHttpUrl, PrivateAttr, @@ -19,6 +18,7 @@ BeforeValidator, field_validator, ) +from oceanum._base import StrictBaseModel from pydantic_core import core_schema from pydantic.json import timedelta_isoformat from typing_extensions import Annotated @@ -138,7 +138,7 @@ def __get_pydantic_json_schema__(cls, schema, handler): ] -class Schema(BaseModel): +class Schema(StrictBaseModel): attrs: Optional[dict] = Field(title="Global attributes", default={}) dims: Optional[dict] = Field(title="Dimensions", default={}) coords: 
Optional[dict] = Field(title="Coordinates", default={}) @@ -191,7 +191,7 @@ class Coordinates(Enum): } -class Datasource(BaseModel): +class Datasource(StrictBaseModel): """Datasource""" id: str = Field( @@ -325,7 +325,7 @@ class Datasource(BaseModel): _detail: bool = PrivateAttr(default=False) # TODO[pydantic]: The following keys were removed: `json_encoders`. # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. - model_config = ConfigDict(use_enum_values=True, validate_assignment=True) + model_config = ConfigDict(extra="forbid", use_enum_values=True, validate_assignment=True) @field_validator("id") @classmethod diff --git a/src/oceanum/datamesh/query.py b/src/oceanum/datamesh/query.py index 9ee6e72..8dc50fc 100644 --- a/src/oceanum/datamesh/query.py +++ b/src/oceanum/datamesh/query.py @@ -6,11 +6,11 @@ from pydantic import ( field_validator, ConfigDict, - BaseModel, Field, BeforeValidator, WithJsonSchema, ) +from oceanum._base import StrictBaseModel from typing import Optional, Dict, Union, List, Any from typing_extensions import Annotated from enum import Enum @@ -160,14 +160,14 @@ class ResampleType(str, Enum): slinear = "linear" -class FilterGeometry(BaseModel): +class FilterGeometry(StrictBaseModel): id: str = Field(title="Datasource ID") parameters: Optional[Dict] = Field( title="Optional parameters to access datasource", default={} ) -class GeoFilter(BaseModel): +class GeoFilter(StrictBaseModel): """GeoFilter class Describes a spatial subset or interpolation """ @@ -225,7 +225,7 @@ def validate_geom(cls, v): return v -class LevelFilter(BaseModel): +class LevelFilter(StrictBaseModel): """LevelFilter class Describes a vertical subset or interpolation """ @@ -253,7 +253,7 @@ class LevelFilter(BaseModel): ) -class TimeFilter(BaseModel): +class TimeFilter(StrictBaseModel): """TimeFilter class Describes a temporal subset or interpolation """ @@ -306,7 +306,7 @@ class AggregateOps(str, Enum): sum = "sum" -class 
Aggregate(BaseModel): +class Aggregate(StrictBaseModel): operations: List[AggregateOps] = Field( title="Aggregate operations to perform", default=[AggregateOps.mean], @@ -325,7 +325,7 @@ class Aggregate(BaseModel): ) -class Function(BaseModel): +class Function(StrictBaseModel): id: str = Field(title="Function id") args: Dict[str, Any] = Field(title="function arguments") vselect: Optional[List[str]] = Field( @@ -345,14 +345,14 @@ class Function(BaseModel): # df ∩ features -> subset of df within (resolution) of features -class CoordSelector(BaseModel): +class CoordSelector(StrictBaseModel): coord: str = Field(title="Coordinate name") values: List[str | int | float] = Field( title="List of coordinate values to select by" ) -class Query(BaseModel): +class Query(StrictBaseModel): """ Datamesh query """ @@ -416,14 +416,7 @@ def __hash__(self): return hash(self.model_dump_json(warnings=False)) -class Workspace(BaseModel): - data: List[Query] = Field(title="Datamesh queries") - id: Optional[str] = Field(title="Unique ID of this package", default=None) - name: Optional[str] = Field(title="Package name", default="OceanQL package") - description: Optional[str] = Field(title="Package description", default="") - - -class Workspace(BaseModel): +class Workspace(StrictBaseModel): data: List[Query] = Field(title="Datamesh queries") id: Optional[str] = Field(title="Unique ID of this package", default=None) name: Optional[str] = Field(title="Package name", default="OceanQL package") @@ -436,7 +429,7 @@ class Container(str, Enum): Dataset = "dataset" -class Stage(BaseModel): +class Stage(StrictBaseModel): query: Query = Field(title="OceanQL query") qhash: str = Field(title="Query hash") formats: List[str] = Field(title="Available download formats") diff --git a/src/oceanum/datamesh/session.py b/src/oceanum/datamesh/session.py index cc5a285..7c5c632 100644 --- a/src/oceanum/datamesh/session.py +++ b/src/oceanum/datamesh/session.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from 
oceanum._base import StrictBaseModel from typing import Optional from datetime import datetime, timedelta from .exceptions import DatameshConnectError, DatameshSessionError @@ -7,7 +7,7 @@ import os -class Session(BaseModel): +class Session(StrictBaseModel): id: str user: str creation_time: datetime From 5f70d66d98e8f5d20dfb7d7b8cdab231bc7b68ce Mon Sep 17 00:00:00 2001 From: rafa-guedes Date: Wed, 18 Feb 2026 17:41:52 +1300 Subject: [PATCH 3/3] Add tests for the stricter behaviour --- tests/test_strict_base_model.py | 225 ++++++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 tests/test_strict_base_model.py diff --git a/tests/test_strict_base_model.py b/tests/test_strict_base_model.py new file mode 100644 index 0000000..16f0b31 --- /dev/null +++ b/tests/test_strict_base_model.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +"""Tests for StrictBaseModel extra='forbid' behavior.""" + +import pytest +from pydantic import ValidationError + +from oceanum._base import StrictBaseModel +from oceanum.datamesh.query import Query, GeoFilter, TimeFilter, Aggregate, Function +from oceanum.datamesh.datasource import Datasource, Schema +from oceanum.datamesh.session import Session +from oceanum.cli.models import TokenResponse, Auth0Config + + +class TestStrictBaseModel: + """Test that StrictBaseModel forbids extra fields.""" + + def test_strict_base_model_forbids_extra_fields(self): + """Test that StrictBaseModel raises ValidationError for extra fields.""" + + class TestModel(StrictBaseModel): + name: str + value: int + + # Valid instantiation should work + model = TestModel(name="test", value=42) + assert model.name == "test" + assert model.value == 42 + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + TestModel(name="test", value=42, extra_field="should fail") + + # Verify the error mentions the extra field + assert "extra_field" in str(exc_info.value) + + def 
test_strict_base_model_catches_typos(self): + """Test that typos in field names are caught.""" + + class TestModel(StrictBaseModel): + datasource: str + description: str = "" + + # Correct field names work + model = TestModel(datasource="test-source", description="A test") + assert model.datasource == "test-source" + + # Typo in field name raises ValidationError + with pytest.raises(ValidationError) as exc_info: + TestModel(datasource="test-source", descrption="typo in description") + + assert "descrption" in str(exc_info.value) + + +class TestQueryExtraForbid: + """Test that Query class forbids extra fields.""" + + def test_query_forbids_extra_fields(self): + """Test that Query raises ValidationError for extra fields.""" + # Valid Query + query = Query(datasource="test-datasource") + assert query.datasource == "test-datasource" + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + Query(datasource="test-datasource", unknown_param="value") + + assert "unknown_param" in str(exc_info.value) + + def test_query_catches_common_typos(self): + """Test that common typos in Query fields are caught.""" + # Typo: 'varaibles' instead of 'variables' + with pytest.raises(ValidationError) as exc_info: + Query(datasource="test-datasource", varaibles=["temp"]) + + assert "varaibles" in str(exc_info.value) + + # Typo: 'timeFilter' instead of 'timefilter' + with pytest.raises(ValidationError) as exc_info: + Query(datasource="test-datasource", timeFilter={}) + + assert "timeFilter" in str(exc_info.value) + + +class TestGeoFilterExtraForbid: + """Test that GeoFilter class forbids extra fields.""" + + def test_geofilter_forbids_extra_fields(self): + """Test that GeoFilter raises ValidationError for extra fields.""" + # Valid GeoFilter + geofilter = GeoFilter(geom=[0, 0, 10, 10]) + assert geofilter.geom == [0, 0, 10, 10] + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + GeoFilter(geom=[0, 0, 
10, 10], unknown_field="value") + + assert "unknown_field" in str(exc_info.value) + + +class TestAggregateExtraForbid: + """Test that Aggregate class forbids extra fields.""" + + def test_aggregate_forbids_extra_fields(self): + """Test that Aggregate raises ValidationError for extra fields.""" + # Valid Aggregate + agg = Aggregate() + assert agg.spatial is True + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + Aggregate(spatail=True) # typo: 'spatail' instead of 'spatial' + + assert "spatail" in str(exc_info.value) + + +class TestFunctionExtraForbid: + """Test that Function class forbids extra fields.""" + + def test_function_forbids_extra_fields(self): + """Test that Function raises ValidationError for extra fields.""" + # Valid Function + func = Function(id="test-func", args={"param": 1}) + assert func.id == "test-func" + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + Function(id="test-func", args={}, unknown="value") + + assert "unknown" in str(exc_info.value) + + +class TestDatasourceExtraForbid: + """Test that Datasource class forbids extra fields.""" + + def test_datasource_forbids_extra_fields(self): + """Test that Datasource raises ValidationError for extra fields.""" + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + Datasource( + id="test-ds", + name="Test Datasource", + driver="onzarr", + unknown_field="value" + ) + + assert "unknown_field" in str(exc_info.value) + + def test_datasource_catches_typos(self): + """Test that typos in Datasource fields are caught.""" + # Typo: 'discription' instead of 'description' + with pytest.raises(ValidationError) as exc_info: + Datasource( + id="test-ds", + name="Test Datasource", + driver="onzarr", + discription="typo" + ) + + assert "discription" in str(exc_info.value) + + +class TestSchemaExtraForbid: + """Test that Schema class forbids extra fields.""" + + def 
test_schema_forbids_extra_fields(self): + """Test that Schema raises ValidationError for extra fields.""" + # Valid Schema + schema = Schema(attrs={"title": "Test"}) + assert schema.attrs == {"title": "Test"} + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + Schema(attrs={}, unknown_field="value") + + assert "unknown_field" in str(exc_info.value) + + +class TestSessionExtraForbid: + """Test that Session class forbids extra fields.""" + + def test_session_forbids_extra_fields(self): + """Test that Session raises ValidationError for extra fields.""" + from datetime import datetime + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + Session( + id="test-session", + user="test-user", + creation_time=datetime.now(), + end_time=datetime.now(), + write=False, + unknown_field="value" + ) + + assert "unknown_field" in str(exc_info.value) + + +class TestCLIModelsExtraForbid: + """Test that CLI models forbid extra fields.""" + + def test_auth0_config_forbids_extra_fields(self): + """Test that Auth0Config raises ValidationError for extra fields.""" + # Valid Auth0Config + config = Auth0Config(domain="test.auth0.com", client_id="abc123") + assert config.domain == "test.auth0.com" + + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + Auth0Config(domain="test.auth0.com", client_id="abc123", extra="value") + + assert "extra" in str(exc_info.value) + + def test_token_response_forbids_extra_fields(self): + """Test that TokenResponse raises ValidationError for extra fields.""" + # Extra field should raise ValidationError + with pytest.raises(ValidationError) as exc_info: + TokenResponse( + access_token="token123", + expires_in=3600, + token_type="Bearer", + unknown_field="value" + ) + + assert "unknown_field" in str(exc_info.value)