diff --git a/RATapi/controls.py b/RATapi/controls.py index fe77e1ff..e874e6f3 100644 --- a/RATapi/controls.py +++ b/RATapi/controls.py @@ -2,6 +2,8 @@ import os import tempfile import warnings +from pathlib import Path +from typing import Union import prettytable from pydantic import ( @@ -220,3 +222,30 @@ def delete_IPC(self): with contextlib.suppress(FileNotFoundError): os.remove(self._IPCFilePath) return None + + def save(self, path: Union[str, Path], filename: str = "controls"): + """Save a controls object to a JSON file. + + Parameters + ---------- + path : str or Path + The directory in which the controls object will be written. + filename : str + The name for the JSON file containing the controls object. + + """ + file = Path(path, f"{filename.removesuffix('.json')}.json") + file.write_text(self.model_dump_json()) + + @classmethod + def load(cls, path: Union[str, Path]) -> "Controls": + """Load a controls object from file. + + Parameters + ---------- + path : str or Path + The path to the controls object file. + + """ + file = Path(path) + return cls.model_validate_json(file.read_text()) diff --git a/RATapi/project.py b/RATapi/project.py index 176298f5..9b38f88e 100644 --- a/RATapi/project.py +++ b/RATapi/project.py @@ -3,6 +3,7 @@ import collections import copy import functools +import json from enum import Enum from pathlib import Path from textwrap import indent @@ -834,6 +835,75 @@ def classlist_script(name, classlist): + "\n)" ) + def save(self, path: Union[str, Path], filename: str = "project"): + """Save a project to a JSON file. + + Parameters + ---------- + path : str or Path + The path in which the project will be written. + filename : str + The name of the generated project file. + + """ + json_dict = {} + for field in self.model_fields: + attr = getattr(self, field) + + if field == "data": + + def make_data_dict(item): + return { + "name": item.name, + "data": item.data.tolist(), + "data_range": item.data_range, + "simulation_range": item.simulation_range, + } + + json_dict["data"] = [make_data_dict(data) for data in attr] + + elif field == "custom_files": + + def make_custom_file_dict(item): + return { + "name": item.name, + "filename": item.filename, + "language": item.language, + "path": str(item.path), + } + + json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr] + + elif isinstance(attr, ClassList): + json_dict[field] = [item.model_dump() for item in attr] + else: + json_dict[field] = attr + + file = Path(path, f"{filename.removesuffix('.json')}.json") + file.write_text(json.dumps(json_dict)) + + @classmethod + def load(cls, path: Union[str, Path]) -> "Project": + """Load a project from file. + + Parameters + ---------- + path : str or Path + The path to the project file. + + """ + input = Path(path).read_text() + model_dict = json.loads(input) + for i in range(0, len(model_dict["data"])): + if model_dict["data"][i]["name"] == "Simulation": + model_dict["data"][i]["data"] = np.empty([0, 3]) + del model_dict["data"][i]["data_range"] + else: + data = model_dict["data"][i]["data"] + model_dict["data"][i]["data"] = np.array(data) + + return cls.model_validate(model_dict) + def _classlist_wrapper(self, class_list: ClassList, func: Callable): """Defines the function used to wrap around ClassList routines to force revalidation. diff --git a/RATapi/utils/convert.py b/RATapi/utils/convert.py index 2afdc32e..e831e905 100644 --- a/RATapi/utils/convert.py +++ b/RATapi/utils/convert.py @@ -1,6 +1,5 @@ """Utilities for converting input files to Python `Project`s.""" -import json import warnings from collections.abc import Iterable from os import PathLike @@ -553,72 +552,3 @@ def convert_parameters( eng.save(str(filename), "problem", nargout=0) eng.exit() return None - - -def project_to_json(project: Project) -> str: - """Write a Project as a JSON file. - - Parameters - ---------- - project : Project - The input Project object to convert. - - Returns - ------- - str - A string representing the class in JSON format. - """ - json_dict = {} - for field in project.model_fields: - attr = getattr(project, field) - - if field == "data": - - def make_data_dict(item): - return { - "name": item.name, - "data": item.data.tolist(), - "data_range": item.data_range, - "simulation_range": item.simulation_range, - } - - json_dict["data"] = [make_data_dict(data) for data in attr] - - elif field == "custom_files": - - def make_custom_file_dict(item): - return {"name": item.name, "filename": item.filename, "language": item.language, "path": str(item.path)} - - json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr] - - elif isinstance(attr, ClassList): - json_dict[field] = [dict(item) for item in attr] - else: - json_dict[field] = attr - - return json.dumps(json_dict) - - -def project_from_json(input: str) -> Project: - """Read a Project from a JSON string generated by `to_json`. - - Parameters - ---------- - input : str - The JSON input as a string. - - Returns - ------- - Project - The project corresponding to that JSON input. - """ - model_dict = json.loads(input) - for i in range(0, len(model_dict["data"])): - if model_dict["data"][i]["name"] == "Simulation": - model_dict["data"][i]["data"] = empty([0, 3]) - del model_dict["data"][i]["data_range"] - else: - data = model_dict["data"][i]["data"] - model_dict["data"][i]["data"] = array(data) - - return Project.model_validate(model_dict) diff --git a/tests/test_convert.py b/tests/test_convert.py index bb157c2f..e636fc9a 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -8,7 +8,7 @@ import pytest import RATapi -from RATapi.utils.convert import project_class_to_r1, project_from_json, project_to_json, r1_to_project_class +from RATapi.utils.convert import project_class_to_r1, r1_to_project_class TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data") @@ -110,35 +110,6 @@ def test_invalid_constraints(): assert output_project.background_parameters[0].min == output_project.background_parameters[0].value -@pytest.mark.parametrize( - "project", - [ - "r1_default_project", - "r1_monolayer", - "r1_monolayer_8_contrasts", - "r1_orso_polymer", - "r1_motofit_bench_mark", - "dspc_bilayer", - "dspc_standard_layers", - "dspc_custom_layers", - "dspc_custom_xy", - "domains_standard_layers", - "domains_custom_layers", - "domains_custom_xy", - "absorption", - ], -) -def test_json_involution(project, request): - """Test that converting a Project to JSON and back returns the same project.""" - original_project = request.getfixturevalue(project) - json_data = project_to_json(original_project) - - converted_project = project_from_json(json_data) - - for field in RATapi.Project.model_fields: - assert getattr(converted_project, field) == getattr(original_project, field) - - @pytest.mark.skipif(importlib.util.find_spec("matlab") is None, reason="Matlab not installed") @pytest.mark.parametrize("path_type", [os.path.join, pathlib.Path]) def test_matlab_save(path_type, request): diff --git a/tests/test_project.py b/tests/test_project.py index 19ef7aa1..ee2136ba 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,6 +1,7 @@ """Test the project module.""" import copy +import tempfile from pathlib import Path from typing import Callable @@ -1531,3 +1532,32 @@ def test_wrap_extend(test_project, class_list: str, model_type: str, field: str, # Ensure invalid model was not appended assert test_attribute == orig_class_list + + +@pytest.mark.parametrize( + "project", + [ + "r1_default_project", + "r1_monolayer", + "r1_monolayer_8_contrasts", + "r1_orso_polymer", + "r1_motofit_bench_mark", + "dspc_standard_layers", + "dspc_custom_layers", + "dspc_custom_xy", + "domains_standard_layers", + "domains_custom_layers", + "domains_custom_xy", + "absorption", + ], +) +def test_save_load(project, request): + """Test that saving and loading a project returns the same project.""" + original_project = request.getfixturevalue(project) + + with tempfile.TemporaryDirectory() as tmp: + original_project.save(tmp) + converted_project = RATapi.Project.load(Path(tmp, "project.json")) + + for field in RATapi.Project.model_fields: + assert getattr(converted_project, field) == getattr(original_project, field)