From 9b4fab85fd5c752ffbfe668db2a1eb3a6da7aee6 Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Thu, 9 Jan 2025 13:31:39 +0000 Subject: [PATCH 1/2] Refactors "set_fields" routine with custom context manager --- RATapi/classlist.py | 62 +++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 2 +- tests/test_classlist.py | 22 +++++++++++++++ 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/RATapi/classlist.py b/RATapi/classlist.py index 2ec2378c..dcb72099 100644 --- a/RATapi/classlist.py +++ b/RATapi/classlist.py @@ -4,6 +4,7 @@ import collections import contextlib +import importlib import warnings from collections.abc import Sequence from typing import Any, Generic, TypeVar, Union @@ -261,9 +262,64 @@ def extend(self, other: Sequence[T]) -> None: def set_fields(self, index: int, **kwargs) -> None: """Assign the values of an existing object's attributes using keyword arguments.""" self._validate_name_field(kwargs) - class_handle = self.data[index].__class__ - new_fields = {**self.data[index].__dict__, **kwargs} - self.data[index] = class_handle(**new_fields) + pydantic_object = False + + if importlib.util.find_spec("pydantic"): + # Pydantic is installed, so set up a context manager that will + # suppress validation errors until all fields have been set. + from pydantic import BaseModel, ValidationError + + if isinstance(self.data[index], BaseModel): + pydantic_object = True + + # Define a custom context manager + class SuppressCustomValidation(contextlib.AbstractContextManager): + """Context manager to suppress "value_error" based validation errors in pydantic. + + This validation context is necessary because errors can occur whilst individual + model values are set, which are resolved when all of the input values are set. + + After the exception is suppressed, execution proceeds with the next + statement following the with statement. + + with SuppressCustomValidation(): + setattr(self.data[index], key, value) + # Execution still resumes here if the attribute cannot be set + """ + + def __init__(self): + pass + + def __enter__(self): + pass + + def __exit__(self, exctype, excinst, exctb): + # If the return of __exit__ is True or truthy, the exception is suppressed. + # Otherwise, the default behaviour of raising the exception applies. + # + # To suppress errors arising from field and model validators in pydantic, + # we will examine the validation errors raised. If all of the errors + # listed in the exception have the type "value_error", this indicates + # they have arisen from field or model validators and will be suppressed. + # Otherwise, they will be raised. + if exctype is None: + return + if issubclass(exctype, ValidationError) and all( + [error["type"] == "value_error" for error in excinst.errors()] + ): + return True + return False + + validation_context = SuppressCustomValidation() + else: + validation_context = contextlib.nullcontext() + + for key, value in kwargs.items(): + with validation_context: + setattr(self.data[index], key, value) + + if pydantic_object: + self._class_handle.model_validate(self.data[index]) def get_names(self) -> list[str]: """Return a list of the values of the name_field attribute of each class object in the list. diff --git a/pyproject.toml b/pyproject.toml index 31b2279f..d20e5b8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ extend-exclude = ["*.ipynb"] [tool.ruff.lint] select = ["E", "F", "UP", "B", "SIM", "I"] -ignore = ["SIM108"] +ignore = ["SIM103", "SIM108"] [tool.ruff.lint.flake8-pytest-style] fixture-parentheses = false diff --git a/tests/test_classlist.py b/tests/test_classlist.py index 4b19f156..71816a5e 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -1005,3 +1005,25 @@ class NestedModel(pydantic.BaseModel): for submodel, exp_dict in zip(model.submodels, submodels_list): for key, value in exp_dict.items(): assert getattr(submodel, key) == value + + def test_set_pydantic_fields(self): + """Test that intermediate validation errors for pydantic models are suppressed when using "set_fields".""" + from pydantic import BaseModel, model_validator + + class MinMaxModel(BaseModel): + min: float + value: float + max: float + + @model_validator(mode="after") + def check_value_in_range(self) -> "MinMaxModel": + if self.value < self.min or self.value > self.max: + raise ValueError( + f"value {self.value} is not within the defined range: {self.min} <= value <= {self.max}" + ) + return self + + model_list = ClassList([MinMaxModel(min=1, value=2, max=5)]) + model_list.set_fields(0, min=3, value=4) + + assert model_list == ClassList([MinMaxModel(min=3.0, value=4.0, max=5.0)]) From 41424a8a9c9be47bb35c39d63998ffc3f8961a18 Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:48:37 +0000 Subject: [PATCH 2/2] Addresses review comments --- RATapi/classlist.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/RATapi/classlist.py b/RATapi/classlist.py index dcb72099..68b0ea3c 100644 --- a/RATapi/classlist.py +++ b/RATapi/classlist.py @@ -266,7 +266,7 @@ def set_fields(self, index: int, **kwargs) -> None: if importlib.util.find_spec("pydantic"): # Pydantic is installed, so set up a context manager that will - # suppress validation errors until all fields have been set. + # suppress custom validation errors until all fields have been set. from pydantic import BaseModel, ValidationError if isinstance(self.data[index], BaseModel): @@ -318,6 +318,9 @@ def __exit__(self, exctype, excinst, exctb): with validation_context: setattr(self.data[index], key, value) + # We have suppressed custom validation errors for pydantic objects. + # We now must revalidate the pydantic model outside the validation context + # to catch any errors that remain after setting all of the fields. if pydantic_object: self._class_handle.model_validate(self.data[index])