Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RATapi/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle
f" controls procedure are:\n "
f"{', '.join(fields.get('procedure', []))}\n",
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
custom_error_list = custom_pydantic_validation_error(exc.errors(include_url=False), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None

if isinstance(model_input, validated_self.__class__):
Expand Down
121 changes: 89 additions & 32 deletions RATapi/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,6 @@ def update_renamed_models(self) -> "Project":
for index, param in all_matches:
if param in params:
setattr(project_field[index], param, new_name)
self._all_names = self.get_all_names()
return self

@model_validator(mode="after")
Expand All @@ -566,28 +565,45 @@ def cross_check_model_values(self) -> "Project":
values = ["value_1", "value_2", "value_3", "value_4", "value_5"]
for field in ["backgrounds", "resolutions"]:
self.check_allowed_source(field)
self.check_allowed_values(field, values, getattr(self, f"{field[:-1]}_parameters").get_names())
self.check_allowed_values(
field,
values,
getattr(self, f"{field[:-1]}_parameters").get_names(),
self._all_names[f"{field[:-1]}_parameters"],
)

self.check_allowed_values(
"layers",
["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness"],
self.parameters.get_names(),
self._all_names["parameters"],
)

self.check_allowed_values("contrasts", ["data"], self.data.get_names())
self.check_allowed_values("contrasts", ["background"], self.backgrounds.get_names())
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names())
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names())
self.check_allowed_values("contrasts", ["scalefactor"], self.scalefactors.get_names())
self.check_allowed_values("contrasts", ["resolution"], self.resolutions.get_names())
self.check_allowed_values("contrasts", ["domain_ratio"], self.domain_ratios.get_names())
self.check_allowed_values("contrasts", ["data"], self.data.get_names(), self._all_names["data"])
self.check_allowed_values(
"contrasts", ["background"], self.backgrounds.get_names(), self._all_names["backgrounds"]
)
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names(), self._all_names["bulk_in"])
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names(), self._all_names["bulk_out"])
self.check_allowed_values(
"contrasts", ["scalefactor"], self.scalefactors.get_names(), self._all_names["scalefactors"]
)
self.check_allowed_values(
"contrasts", ["resolution"], self.resolutions.get_names(), self._all_names["resolutions"]
)
self.check_allowed_values(
"contrasts", ["domain_ratio"], self.domain_ratios.get_names(), self._all_names["domain_ratios"]
)

self.check_contrast_model_allowed_values(
"contrasts",
getattr(self, self._contrast_model_field).get_names(),
self._all_names[self._contrast_model_field],
self._contrast_model_field,
)
self.check_contrast_model_allowed_values("domain_contrasts", self.layers.get_names(), "layers")
self.check_contrast_model_allowed_values(
"domain_contrasts", self.layers.get_names(), self._all_names["layers"], "layers"
)
return self

@model_validator(mode="after")
Expand All @@ -606,6 +622,12 @@ def check_protected_parameters(self) -> "Project":
self._protected_parameters = self.get_all_protected_parameters()
return self

@model_validator(mode="after")
def update_names(self) -> "Project":
"""Following validation, update the list of all parameter names."""
self._all_names = self.get_all_names()
return self

def __str__(self):
output = ""
for key, value in self.__dict__.items():
Expand All @@ -630,7 +652,9 @@ def get_all_protected_parameters(self):
for class_list in parameter_class_lists
}

def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
def check_allowed_values(
self, attribute: str, field_list: list[str], allowed_values: list[str], previous_values: list[str]
) -> None:
"""Check the values of the given fields in the given model are in the supplied list of allowed values.

Parameters
Expand All @@ -641,6 +665,8 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
The fields of the attribute to be checked for valid values.
allowed_values : list [str]
The list of allowed values for the fields given in field_list.
previous_values : list [str]
The list of allowed values for the fields given in field_list after the previous validation.

Raises
------
Expand All @@ -649,14 +675,22 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va

"""
class_list = getattr(self, attribute)
for model in class_list:
for index, model in enumerate(class_list):
for field in field_list:
value = getattr(model, field, "")
if value and value not in allowed_values:
raise ValueError(
f'The value "{value}" in the "{field}" field of "{attribute}" must be defined in '
f'"{values_defined_in[f"{attribute}.{field}"]}".',
)
if value in previous_values:
raise ValueError(
f'The value "{value}" used in the "{field}" field of {attribute}[{index}] must be defined '
f'in "{values_defined_in[f"{attribute}.{field}"]}". Please remove "{value}" from '
f'"{attribute}[{index}].{field}" before attempting to delete it.',
)
else:
raise ValueError(
f'The value "{value}" used in the "{field}" field of {attribute}[{index}] must be defined '
f'in "{values_defined_in[f"{attribute}.{field}"]}". Please add "{value}" to '
f'"{values_defined_in[f"{attribute}.{field}"]}" before including it in "{attribute}".',
)

def check_allowed_source(self, attribute: str) -> None:
"""Check that the source of a background or resolution is defined in the relevant field for its type.
Expand All @@ -679,24 +713,37 @@ def check_allowed_source(self, attribute: str) -> None:

"""
class_list = getattr(self, attribute)
for model in class_list:
for index, model in enumerate(class_list):
if model.type == TypeOptions.Constant:
allowed_values = getattr(self, f"{attribute[:-1]}_parameters").get_names()
previous_values = self._all_names[f"{attribute[:-1]}_parameters"]
elif model.type == TypeOptions.Data:
allowed_values = self.data.get_names()
previous_values = self._all_names["data"]
else:
allowed_values = self.custom_files.get_names()
previous_values = self._all_names["custom_files"]

if (value := model.source) != "" and value not in allowed_values:
raise ValueError(
f'The value "{value}" in the "source" field of "{attribute}" must be defined in '
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}".',
)
if value in previous_values:
raise ValueError(
f'The value "{value}" used in the "source" field of {attribute}[{index}] must be defined in '
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}". Please remove "{value}" from '
f'"{attribute}[{index}].source" before attempting to delete it.',
)
else:
raise ValueError(
f'The value "{value}" used in the "source" field of {attribute}[{index}] must be defined in '
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}". Please add "{value}" to '
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}" before including it in '
f'"{attribute}".',
)

def check_contrast_model_allowed_values(
self,
contrast_attribute: str,
allowed_values: list[str],
previous_values: list[str],
allowed_field: str,
) -> None:
"""Ensure the contents of the ``model`` for a contrast or domain contrast exist in the required project fields.
Expand All @@ -707,6 +754,8 @@ def check_contrast_model_allowed_values(
The specific contrast attribute of Project being validated (either "contrasts" or "domain_contrasts").
allowed_values : list [str]
The list of allowed values for the model of the contrast_attribute.
previous_values : list [str]
The list of allowed values for the model of the contrast_attribute after the previous validation.
allowed_field : str
The name of the field in the project in which the allowed_values are defined.

Expand All @@ -717,13 +766,22 @@ def check_contrast_model_allowed_values(

"""
class_list = getattr(self, contrast_attribute)
for contrast in class_list:
model_values = contrast.model
if model_values and not all(value in allowed_values for value in model_values):
raise ValueError(
f'The values: "{", ".join(str(i) for i in model_values)}" in the "model" field of '
f'"{contrast_attribute}" must be defined in "{allowed_field}".',
)
for index, contrast in enumerate(class_list):
if (model_values := contrast.model) and (missing_values := list(set(model_values) - set(allowed_values))):
if all(value in previous_values for value in model_values):
raise ValueError(
f"The value{'s' if len(missing_values) > 1 else ''}: "
f'"{", ".join(str(i) for i in missing_values)}" used in the "model" field of '
f'{contrast_attribute}[{index}] must be defined in "{allowed_field}". Please remove all '
f'unnecessary values from "model" before attempting to delete them.',
)
else:
raise ValueError(
f"The value{'s' if len(missing_values) > 1 else ''}: "
f'"{", ".join(str(i) for i in missing_values)}" used in the "model" field of '
f'{contrast_attribute}[{index}] must be defined in "{allowed_field}". Please add all '
f'required values to "{allowed_field}" before including them in "{contrast_attribute}".',
)

def get_contrast_model_field(self):
"""Get the field used to define the contents of the "model" field in contrasts.
Expand Down Expand Up @@ -945,7 +1003,7 @@ def wrapped_func(*args, **kwargs):
Project.model_validate(self)
except ValidationError as exc:
class_list.data = previous_state
custom_error_list = custom_pydantic_validation_error(exc.errors())
custom_error_list = custom_pydantic_validation_error(exc.errors(include_url=False))
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None
except (TypeError, ValueError):
class_list.data = previous_state
Expand Down Expand Up @@ -980,9 +1038,8 @@ def try_relative_to(path: Path, relative_to: Path) -> str:
else:
warnings.warn(
"Could not save custom file path as relative to the project directory, "
"which means that it may not work on other devices."
"If you would like to share your project, make sure your custom files "
"are in a subfolder of the project save location.",
"which means that it may not work on other devices. If you would like to share your project, "
"make sure your custom files are in a subfolder of the project save location.",
stacklevel=2,
)
return str(path.resolve())
Loading