diff --git a/RATapi/inputs.py b/RATapi/inputs.py index 06d9f82c..e362cf6d 100644 --- a/RATapi/inputs.py +++ b/RATapi/inputs.py @@ -179,7 +179,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: for layer in project.layers: layer_params = [ project.parameters.index(getattr(layer, attribute), True) - for attribute in list(layer.model_fields.keys())[1:-2] + for attribute in list(RATapi.models.Layer.model_fields.keys())[1:-2] ] layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float("NaN")) layer_params.append(hydrate_id[layer.hydrate_with]) diff --git a/RATapi/project.py b/RATapi/project.py index ae2ed2cc..0dfbd0a5 100644 --- a/RATapi/project.py +++ b/RATapi/project.py @@ -365,12 +365,12 @@ def model_post_init(self, __context: Any) -> None: and wrap ClassList routines to control revalidation. """ # Ensure all ClassLists have the correct _class_handle defined - for field in (fields := self.model_fields): - type = fields[field].annotation - if get_origin(type) == ClassList: + for field in (fields := Project.model_fields): + annotation = fields[field].annotation + if get_origin(annotation) == ClassList: classlist = getattr(self, field) if not hasattr(field, "_class_handle"): - classlist._class_handle = get_args(type)[0] + classlist._class_handle = get_args(annotation)[0] layers_field = self.layers if not hasattr(layers_field, "_class_handle"): diff --git a/tests/test_project.py b/tests/test_project.py index e82a84da..c0e26685 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -141,7 +141,7 @@ def test_classlists(test_project) -> None: """The ClassLists in the "Project" model should contain instances of the models given by the dictionary "model_in_classlist". """ - for model in (fields := test_project.model_fields): + for model in (fields := RATapi.Project.model_fields): if get_origin(fields[model].annotation) == RATapi.ClassList: class_list = getattr(test_project, model) assert class_list._class_handle == get_args(fields[model].annotation)[0] @@ -1613,7 +1613,7 @@ def test_save_load(project, request): for file in original_project.custom_files: file.path = file.path.resolve() - for field in original_project.model_fields: + for field in RATapi.Project.model_fields: assert getattr(converted_project, field) == getattr(original_project, field)