From a030d9dbe022010c7d3517b09053392b968f7d77 Mon Sep 17 00:00:00 2001 From: alexhroom Date: Wed, 22 Jan 2025 12:53:03 +0000 Subject: [PATCH 1/4] added save and load method to project and controls --- RATapi/controls.py | 26 ++++++++++++++++++++++++++ RATapi/project.py | 26 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/RATapi/controls.py b/RATapi/controls.py index fe77e1ff..8c41dc58 100644 --- a/RATapi/controls.py +++ b/RATapi/controls.py @@ -2,6 +2,7 @@ import os import tempfile import warnings +from pathlib import Path import prettytable from pydantic import ( @@ -220,3 +221,28 @@ def delete_IPC(self): with contextlib.suppress(FileNotFoundError): os.remove(self._IPCFilePath) return None + + def save(self, path: str | Path, filename: str = "controls"): + """Save a controls object to a JSON file. + + Parameters + ---------- + path : str or Path + The directory in which the controls object will be written. + + """ + file = Path(path, f"{filename.removesuffix('.json')}.json") + file.write_text(self.model_dump_json()) + + @classmethod + def load(cls, path: str | Path) -> "Controls": + """Load a controls object from file. + + Parameters + ---------- + path : str or Path + The path to the controls object file. + + """ + file = Path(path) + return cls.model_validate_json(file.read_text()) diff --git a/RATapi/project.py b/RATapi/project.py index 176298f5..7920b504 100644 --- a/RATapi/project.py +++ b/RATapi/project.py @@ -23,6 +23,7 @@ import RATapi.models from RATapi.classlist import ClassList +from RATapi.utils.convert import project_from_json, project_to_json from RATapi.utils.custom_errors import custom_pydantic_validation_error from RATapi.utils.enums import Calculations, Geometries, LayerModels, Priors, TypeOptions @@ -834,6 +835,31 @@ def classlist_script(name, classlist): + "\n)" ) + def save(self, path: str | Path, filename: str = "project"): + """Save a project to a JSON file. + + Parameters + ---------- + path : str or Path + The directory in which the project will be written. + + """ + file = Path(path, f"{filename.removesuffix('.json')}.json") + file.write_text(project_to_json(self)) + + @classmethod + def load(cls, path: str | Path) -> "Project": + """Load a project from file. + + Parameters + ---------- + path : str or Path + The path to the project file. + + """ + file = Path(path) + return project_from_json(file.read_text()) + def _classlist_wrapper(self, class_list: ClassList, func: Callable): """Defines the function used to wrap around ClassList routines to force revalidation. From e9d641c323a4f513fb6cbae36fa186d2e5ce8bb2 Mon Sep 17 00:00:00 2001 From: alexhroom Date: Fri, 24 Jan 2025 12:16:06 +0000 Subject: [PATCH 2/4] moved json save/load to the project method --- RATapi/project.py | 58 +++++++++++++++++++++++++++++----- RATapi/utils/convert.py | 70 ----------------------------------------- tests/test_convert.py | 31 +----------------- tests/test_project.py | 29 +++++++++++++++++ 4 files changed, 81 insertions(+), 107 deletions(-) diff --git a/RATapi/project.py b/RATapi/project.py index 7920b504..c2511ed5 100644 --- a/RATapi/project.py +++ b/RATapi/project.py @@ -3,6 +3,7 @@ import collections import copy import functools +import json from enum import Enum from pathlib import Path from textwrap import indent @@ -23,7 +24,6 @@ import RATapi.models from RATapi.classlist import ClassList -from RATapi.utils.convert import project_from_json, project_to_json from RATapi.utils.custom_errors import custom_pydantic_validation_error from RATapi.utils.enums import Calculations, Geometries, LayerModels, Priors, TypeOptions @@ -835,20 +835,55 @@ def classlist_script(name, classlist): + "\n)" ) - def save(self, path: str | Path, filename: str = "project"): + def save(self, path: Union[str, Path], filename: str = "project"): """Save a project to a JSON file. Parameters ---------- path : str or Path - The directory in which the project will be written. + The path in which the project will be written. + filename : str + The name of the generated project file. """ + json_dict = {} + for field in self.model_fields: + attr = getattr(self, field) + + if field == "data": + + def make_data_dict(item): + return { + "name": item.name, + "data": item.data.tolist(), + "data_range": item.data_range, + "simulation_range": item.simulation_range, + } + + json_dict["data"] = [make_data_dict(data) for data in attr] + + elif field == "custom_files": + + def make_custom_file_dict(item): + return { + "name": item.name, + "filename": item.filename, + "language": item.language, + "path": str(item.path), + } + + json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr] + + elif isinstance(attr, ClassList): + json_dict[field] = [dict(item) for item in attr] + else: + json_dict[field] = attr + file = Path(path, f"{filename.removesuffix('.json')}.json") - file.write_text(project_to_json(self)) + file.write_text(json.dumps(json_dict)) @classmethod - def load(cls, path: str | Path) -> "Project": + def load(cls, path: Union[str, Path]) -> "Project": """Load a project from file. Parameters @@ -857,8 +892,17 @@ def load(cls, path: str | Path) -> "Project": The path to the project file. """ - file = Path(path) - return project_from_json(file.read_text()) + input = Path(path).read_text() + model_dict = json.loads(input) + for i in range(0, len(model_dict["data"])): + if model_dict["data"][i]["name"] == "Simulation": + model_dict["data"][i]["data"] = np.empty([0, 3]) + del model_dict["data"][i]["data_range"] + else: + data = model_dict["data"][i]["data"] + model_dict["data"][i]["data"] = np.array(data) + + return cls.model_validate(model_dict) def _classlist_wrapper(self, class_list: ClassList, func: Callable): """Defines the function used to wrap around ClassList routines to force revalidation. diff --git a/RATapi/utils/convert.py b/RATapi/utils/convert.py index 2afdc32e..e831e905 100644 --- a/RATapi/utils/convert.py +++ b/RATapi/utils/convert.py @@ -1,6 +1,5 @@ """Utilities for converting input files to Python `Project`s.""" -import json import warnings from collections.abc import Iterable from os import PathLike @@ -553,72 +552,3 @@ def convert_parameters( eng.save(str(filename), "problem", nargout=0) eng.exit() return None - - -def project_to_json(project: Project) -> str: - """Write a Project as a JSON file. - - Parameters - ---------- - project : Project - The input Project object to convert. - - Returns - ------- - str - A string representing the class in JSON format. - """ - json_dict = {} - for field in project.model_fields: - attr = getattr(project, field) - - if field == "data": - - def make_data_dict(item): - return { - "name": item.name, - "data": item.data.tolist(), - "data_range": item.data_range, - "simulation_range": item.simulation_range, - } - - json_dict["data"] = [make_data_dict(data) for data in attr] - - elif field == "custom_files": - - def make_custom_file_dict(item): - return {"name": item.name, "filename": item.filename, "language": item.language, "path": str(item.path)} - - json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr] - - elif isinstance(attr, ClassList): - json_dict[field] = [dict(item) for item in attr] - else: - json_dict[field] = attr - - return json.dumps(json_dict) - - -def project_from_json(input: str) -> Project: - """Read a Project from a JSON string generated by `to_json`. - - Parameters - ---------- - input : str - The JSON input as a string. - - Returns - ------- - Project - The project corresponding to that JSON input. - """ - model_dict = json.loads(input) - for i in range(0, len(model_dict["data"])): - if model_dict["data"][i]["name"] == "Simulation": - model_dict["data"][i]["data"] = empty([0, 3]) - del model_dict["data"][i]["data_range"] - else: - data = model_dict["data"][i]["data"] - model_dict["data"][i]["data"] = array(data) - - return Project.model_validate(model_dict) diff --git a/tests/test_convert.py b/tests/test_convert.py index bb157c2f..e636fc9a 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -8,7 +8,7 @@ import pytest import RATapi -from RATapi.utils.convert import project_class_to_r1, project_from_json, project_to_json, r1_to_project_class +from RATapi.utils.convert import project_class_to_r1, r1_to_project_class TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data") @@ -110,35 +110,6 @@ def test_invalid_constraints(): assert output_project.background_parameters[0].min == output_project.background_parameters[0].value -@pytest.mark.parametrize( - "project", - [ - "r1_default_project", - "r1_monolayer", - "r1_monolayer_8_contrasts", - "r1_orso_polymer", - "r1_motofit_bench_mark", - "dspc_bilayer", - "dspc_standard_layers", - "dspc_custom_layers", - "dspc_custom_xy", - "domains_standard_layers", - "domains_custom_layers", - "domains_custom_xy", - "absorption", - ], -) -def test_json_involution(project, request): - """Test that converting a Project to JSON and back returns the same project.""" - original_project = request.getfixturevalue(project) - json_data = project_to_json(original_project) - - converted_project = project_from_json(json_data) - - for field in RATapi.Project.model_fields: - assert getattr(converted_project, field) == getattr(original_project, field) - - @pytest.mark.skipif(importlib.util.find_spec("matlab") is None, reason="Matlab not installed") @pytest.mark.parametrize("path_type", [os.path.join, pathlib.Path]) def test_matlab_save(path_type, request): diff --git a/tests/test_project.py b/tests/test_project.py index 19ef7aa1..f23fe102 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1531,3 +1531,32 @@ def test_wrap_extend(test_project, class_list: str, model_type: str, field: str, # Ensure invalid model was not appended assert test_attribute == orig_class_list + + +@pytest.mark.parametrize( + "project", + [ + "r1_default_project", + "r1_monolayer", + "r1_monolayer_8_contrasts", + "r1_orso_polymer", + "r1_motofit_bench_mark", + "dspc_standard_layers", + "dspc_custom_layers", + "dspc_custom_xy", + "domains_standard_layers", + "domains_custom_layers", + "domains_custom_xy", + "absorption", + ], +) +def test_save_load(project, request): + """Test that saving and loading a project returns the same project.""" + original_project = request.getfixturevalue(project) + + with tempfile.TemporaryDirectory() as tmp: + original_project.save(tmp) + converted_project = RATapi.Project.load(Path(tmp, "project.json")) + + for field in RATapi.Project.model_fields: + assert getattr(converted_project, field) == getattr(original_project, field) From 4297b71c75fc0f3b1584c78400d396803facfd4c Mon Sep 17 00:00:00 2001 From: alexhroom Date: Fri, 24 Jan 2025 14:37:48 +0000 Subject: [PATCH 3/4] backwards compatibility --- RATapi/controls.py | 5 +++-- tests/test_project.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/RATapi/controls.py b/RATapi/controls.py index 8c41dc58..4983173b 100644 --- a/RATapi/controls.py +++ b/RATapi/controls.py @@ -3,6 +3,7 @@ import tempfile import warnings from pathlib import Path +from typing import Union import prettytable from pydantic import ( @@ -222,7 +223,7 @@ def delete_IPC(self): os.remove(self._IPCFilePath) return None - def save(self, path: str | Path, filename: str = "controls"): + def save(self, path: Union[str, Path], filename: str = "controls"): """Save a controls object to a JSON file. Parameters @@ -235,7 +236,7 @@ def save(self, path: str | Path, filename: str = "controls"): file.write_text(self.model_dump_json()) @classmethod - def load(cls, path: str | Path) -> "Controls": + def load(cls, path: Union[str, Path]) -> "Controls": """Load a controls object from file. Parameters diff --git a/tests/test_project.py b/tests/test_project.py index f23fe102..ee2136ba 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,6 +1,7 @@ """Test the project module.""" import copy +import tempfile from pathlib import Path from typing import Callable From de8b12cbf02a650d24f7272c3c0a5ede3e5cde4b Mon Sep 17 00:00:00 2001 From: alexhroom Date: Mon, 27 Jan 2025 13:56:08 +0000 Subject: [PATCH 4/4] fixed docstring and changed to model dump --- RATapi/controls.py | 2 ++ RATapi/project.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/RATapi/controls.py b/RATapi/controls.py index 4983173b..e874e6f3 100644 --- a/RATapi/controls.py +++ b/RATapi/controls.py @@ -230,6 +230,8 @@ def save(self, path: Union[str, Path], filename: str = "controls"): ---------- path : str or Path The directory in which the controls object will be written. + filename : str + The name for the JSON file containing the controls object. """ file = Path(path, f"{filename.removesuffix('.json')}.json") diff --git a/RATapi/project.py b/RATapi/project.py index c2511ed5..9b38f88e 100644 --- a/RATapi/project.py +++ b/RATapi/project.py @@ -875,7 +875,7 @@ def make_custom_file_dict(item): json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr] elif isinstance(attr, ClassList): - json_dict[field] = [dict(item) for item in attr] + json_dict[field] = [item.model_dump() for item in attr] else: json_dict[field] = attr