Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions RATapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import collections
import contextlib
import importlib
import warnings
from collections.abc import Sequence
from typing import Any, Generic, TypeVar, Union
Expand Down Expand Up @@ -261,9 +262,67 @@ def extend(self, other: Sequence[T]) -> None:
def set_fields(self, index: int, **kwargs) -> None:
"""Assign the values of an existing object's attributes using keyword arguments."""
self._validate_name_field(kwargs)
class_handle = self.data[index].__class__
new_fields = {**self.data[index].__dict__, **kwargs}
self.data[index] = class_handle(**new_fields)
pydantic_object = False

if importlib.util.find_spec("pydantic"):
# Pydantic is installed, so set up a context manager that will
# suppress custom validation errors until all fields have been set.
from pydantic import BaseModel, ValidationError

if isinstance(self.data[index], BaseModel):
pydantic_object = True

# Define a custom context manager
class SuppressCustomValidation(contextlib.AbstractContextManager):
"""Context manager to suppress "value_error" based validation errors in pydantic.

This validation context is necessary because errors can occur whilst individual
model values are set, which are resolved when all of the input values are set.

After the exception is suppressed, execution proceeds with the next
statement following the with statement.

with SuppressCustomValidation():
setattr(self.data[index], key, value)
# Execution still resumes here if the attribute cannot be set
"""

def __init__(self):
pass

def __enter__(self):
pass

def __exit__(self, exctype, excinst, exctb):
# If the return of __exit__ is True or truthy, the exception is suppressed.
# Otherwise, the default behaviour of raising the exception applies.
#
# To suppress errors arising from field and model validators in pydantic,
# we will examine the validation errors raised. If all of the errors
# listed in the exception have the type "value_error", this indicates
# they have arisen from field or model validators and will be suppressed.
# Otherwise, they will be raised.
if exctype is None:
return
if issubclass(exctype, ValidationError) and all(
[error["type"] == "value_error" for error in excinst.errors()]
):
return True
return False

validation_context = SuppressCustomValidation()
else:
validation_context = contextlib.nullcontext()

for key, value in kwargs.items():
with validation_context:
setattr(self.data[index], key, value)

# We have suppressed custom validation errors for pydantic objects.
# We now must revalidate the pydantic model outside the validation context
# to catch any errors that remain after setting all of the fields.
if pydantic_object:
self._class_handle.model_validate(self.data[index])

def get_names(self) -> list[str]:
"""Return a list of the values of the name_field attribute of each class object in the list.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extend-exclude = ["*.ipynb"]

[tool.ruff.lint]
select = ["E", "F", "UP", "B", "SIM", "I"]
ignore = ["SIM108"]
ignore = ["SIM103", "SIM108"]

[tool.ruff.lint.flake8-pytest-style]
fixture-parentheses = false
Expand Down
22 changes: 22 additions & 0 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,3 +1005,25 @@ class NestedModel(pydantic.BaseModel):
for submodel, exp_dict in zip(model.submodels, submodels_list):
for key, value in exp_dict.items():
assert getattr(submodel, key) == value

def test_set_pydantic_fields(self):
"""Test that intermediate validation errors for pydantic models are suppressed when using "set_fields"."""
from pydantic import BaseModel, model_validator

class MinMaxModel(BaseModel):
min: float
value: float
max: float

@model_validator(mode="after")
def check_value_in_range(self) -> "MinMaxModel":
if self.value < self.min or self.value > self.max:
raise ValueError(
f"value {self.value} is not within the defined range: {self.min} <= value <= {self.max}"
)
return self

model_list = ClassList([MinMaxModel(min=1, value=2, max=5)])
model_list.set_fields(0, min=3, value=4)

assert model_list == ClassList([MinMaxModel(min=3.0, value=4.0, max=5.0)])
Loading