diff --git a/RATapi/classlist.py b/RATapi/classlist.py index 2ec2378c..68b0ea3c 100644 --- a/RATapi/classlist.py +++ b/RATapi/classlist.py @@ -4,6 +4,7 @@ import collections import contextlib +import importlib import warnings from collections.abc import Sequence from typing import Any, Generic, TypeVar, Union @@ -261,9 +262,67 @@ def extend(self, other: Sequence[T]) -> None: def set_fields(self, index: int, **kwargs) -> None: """Assign the values of an existing object's attributes using keyword arguments.""" self._validate_name_field(kwargs) - class_handle = self.data[index].__class__ - new_fields = {**self.data[index].__dict__, **kwargs} - self.data[index] = class_handle(**new_fields) + pydantic_object = False + + if importlib.util.find_spec("pydantic"): + # Pydantic is installed, so set up a context manager that will + # suppress custom validation errors until all fields have been set. + from pydantic import BaseModel, ValidationError + + if isinstance(self.data[index], BaseModel): + pydantic_object = True + + # Define a custom context manager + class SuppressCustomValidation(contextlib.AbstractContextManager): + """Context manager to suppress "value_error" based validation errors in pydantic. + + This validation context is necessary because errors can occur whilst individual + model values are set, which are resolved when all of the input values are set. + + After the exception is suppressed, execution proceeds with the next + statement following the with statement. + + with SuppressCustomValidation(): + setattr(self.data[index], key, value) + # Execution still resumes here if the attribute cannot be set + """ + + def __init__(self): + pass + + def __enter__(self): + pass + + def __exit__(self, exctype, excinst, exctb): + # If the return of __exit__ is True or truthy, the exception is suppressed. + # Otherwise, the default behaviour of raising the exception applies. + # + # To suppress errors arising from field and model validators in pydantic, + # we will examine the validation errors raised. If all of the errors + # listed in the exception have the type "value_error", this indicates + # they have arisen from field or model validators and will be suppressed. + # Otherwise, they will be raised. + if exctype is None: + return + if issubclass(exctype, ValidationError) and all( + [error["type"] == "value_error" for error in excinst.errors()] + ): + return True + return False + + validation_context = SuppressCustomValidation() + else: + validation_context = contextlib.nullcontext() + + for key, value in kwargs.items(): + with validation_context: + setattr(self.data[index], key, value) + + # We have suppressed custom validation errors for pydantic objects. + # We now must revalidate the pydantic model outside the validation context + # to catch any errors that remain after setting all of the fields. + if pydantic_object: + self._class_handle.model_validate(self.data[index]) def get_names(self) -> list[str]: """Return a list of the values of the name_field attribute of each class object in the list. diff --git a/pyproject.toml b/pyproject.toml index 31b2279f..d20e5b8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ extend-exclude = ["*.ipynb"] [tool.ruff.lint] select = ["E", "F", "UP", "B", "SIM", "I"] -ignore = ["SIM108"] +ignore = ["SIM103", "SIM108"] [tool.ruff.lint.flake8-pytest-style] fixture-parentheses = false diff --git a/tests/test_classlist.py b/tests/test_classlist.py index 4b19f156..71816a5e 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -1005,3 +1005,25 @@ class NestedModel(pydantic.BaseModel): for submodel, exp_dict in zip(model.submodels, submodels_list): for key, value in exp_dict.items(): assert getattr(submodel, key) == value + + def test_set_pydantic_fields(self): + """Test that intermediate validation errors for pydantic models are suppressed when using "set_fields".""" + from pydantic import BaseModel, model_validator + + class MinMaxModel(BaseModel): + min: float + value: float + max: float + + @model_validator(mode="after") + def check_value_in_range(self) -> "MinMaxModel": + if self.value < self.min or self.value > self.max: + raise ValueError( + f"value {self.value} is not within the defined range: {self.min} <= value <= {self.max}" + ) + return self + + model_list = ClassList([MinMaxModel(min=1, value=2, max=5)]) + model_list.set_fields(0, min=3, value=4) + + assert model_list == ClassList([MinMaxModel(min=3.0, value=4.0, max=5.0)])