Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions RATapi/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import tempfile
import warnings
from pathlib import Path
from typing import Union

import prettytable
from pydantic import (
Expand Down Expand Up @@ -220,3 +222,30 @@ def delete_IPC(self):
with contextlib.suppress(FileNotFoundError):
os.remove(self._IPCFilePath)
return None

def save(self, path: Union[str, Path], filename: str = "controls"):
"""Save a controls object to a JSON file.

Parameters
----------
path : str or Path
The directory in which the controls object will be written.
filename : str
The name for the JSON file containing the controls object.

"""
file = Path(path, f"{filename.removesuffix('.json')}.json")
file.write_text(self.model_dump_json())

@classmethod
def load(cls, path: Union[str, Path]) -> "Controls":
"""Load a controls object from file.

Parameters
----------
path : str or Path
The path to the controls object file.

"""
file = Path(path)
return cls.model_validate_json(file.read_text())
70 changes: 70 additions & 0 deletions RATapi/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import copy
import functools
import json
from enum import Enum
from pathlib import Path
from textwrap import indent
Expand Down Expand Up @@ -834,6 +835,75 @@ def classlist_script(name, classlist):
+ "\n)"
)

def save(self, path: Union[str, Path], filename: str = "project"):
"""Save a project to a JSON file.

Parameters
----------
path : str or Path
The path in which the project will be written.
filename : str
The name of the generated project file.

"""
json_dict = {}
for field in self.model_fields:
attr = getattr(self, field)

if field == "data":

def make_data_dict(item):
return {
"name": item.name,
"data": item.data.tolist(),
"data_range": item.data_range,
"simulation_range": item.simulation_range,
}

json_dict["data"] = [make_data_dict(data) for data in attr]

elif field == "custom_files":

def make_custom_file_dict(item):
return {
"name": item.name,
"filename": item.filename,
"language": item.language,
"path": str(item.path),
}

json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr]

elif isinstance(attr, ClassList):
json_dict[field] = [item.model_dump() for item in attr]
else:
json_dict[field] = attr

file = Path(path, f"{filename.removesuffix('.json')}.json")
file.write_text(json.dumps(json_dict))

@classmethod
def load(cls, path: Union[str, Path]) -> "Project":
"""Load a project from file.

Parameters
----------
path : str or Path
The path to the project file.

"""
input = Path(path).read_text()
model_dict = json.loads(input)
for i in range(0, len(model_dict["data"])):
if model_dict["data"][i]["name"] == "Simulation":
model_dict["data"][i]["data"] = np.empty([0, 3])
del model_dict["data"][i]["data_range"]
else:
data = model_dict["data"][i]["data"]
model_dict["data"][i]["data"] = np.array(data)

return cls.model_validate(model_dict)

def _classlist_wrapper(self, class_list: ClassList, func: Callable):
"""Defines the function used to wrap around ClassList routines to force revalidation.

Expand Down
70 changes: 0 additions & 70 deletions RATapi/utils/convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Utilities for converting input files to Python `Project`s."""

import json
import warnings
from collections.abc import Iterable
from os import PathLike
Expand Down Expand Up @@ -553,72 +552,3 @@ def convert_parameters(
eng.save(str(filename), "problem", nargout=0)
eng.exit()
return None


def project_to_json(project: Project) -> str:
"""Write a Project as a JSON file.

Parameters
----------
project : Project
The input Project object to convert.

Returns
-------
str
A string representing the class in JSON format.
"""
json_dict = {}
for field in project.model_fields:
attr = getattr(project, field)

if field == "data":

def make_data_dict(item):
return {
"name": item.name,
"data": item.data.tolist(),
"data_range": item.data_range,
"simulation_range": item.simulation_range,
}

json_dict["data"] = [make_data_dict(data) for data in attr]

elif field == "custom_files":

def make_custom_file_dict(item):
return {"name": item.name, "filename": item.filename, "language": item.language, "path": str(item.path)}

json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr]

elif isinstance(attr, ClassList):
json_dict[field] = [dict(item) for item in attr]
else:
json_dict[field] = attr

return json.dumps(json_dict)


def project_from_json(input: str) -> Project:
"""Read a Project from a JSON string generated by `to_json`.

Parameters
----------
input : str
The JSON input as a string.

Returns
-------
Project
The project corresponding to that JSON input.
"""
model_dict = json.loads(input)
for i in range(0, len(model_dict["data"])):
if model_dict["data"][i]["name"] == "Simulation":
model_dict["data"][i]["data"] = empty([0, 3])
del model_dict["data"][i]["data_range"]
else:
data = model_dict["data"][i]["data"]
model_dict["data"][i]["data"] = array(data)

return Project.model_validate(model_dict)
31 changes: 1 addition & 30 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import RATapi
from RATapi.utils.convert import project_class_to_r1, project_from_json, project_to_json, r1_to_project_class
from RATapi.utils.convert import project_class_to_r1, r1_to_project_class

TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data")

Expand Down Expand Up @@ -110,35 +110,6 @@ def test_invalid_constraints():
assert output_project.background_parameters[0].min == output_project.background_parameters[0].value


@pytest.mark.parametrize(
"project",
[
"r1_default_project",
"r1_monolayer",
"r1_monolayer_8_contrasts",
"r1_orso_polymer",
"r1_motofit_bench_mark",
"dspc_bilayer",
"dspc_standard_layers",
"dspc_custom_layers",
"dspc_custom_xy",
"domains_standard_layers",
"domains_custom_layers",
"domains_custom_xy",
"absorption",
],
)
def test_json_involution(project, request):
"""Test that converting a Project to JSON and back returns the same project."""
original_project = request.getfixturevalue(project)
json_data = project_to_json(original_project)

converted_project = project_from_json(json_data)

for field in RATapi.Project.model_fields:
assert getattr(converted_project, field) == getattr(original_project, field)


@pytest.mark.skipif(importlib.util.find_spec("matlab") is None, reason="Matlab not installed")
@pytest.mark.parametrize("path_type", [os.path.join, pathlib.Path])
def test_matlab_save(path_type, request):
Expand Down
30 changes: 30 additions & 0 deletions tests/test_project.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test the project module."""

import copy
import tempfile
from pathlib import Path
from typing import Callable

Expand Down Expand Up @@ -1531,3 +1532,32 @@ def test_wrap_extend(test_project, class_list: str, model_type: str, field: str,

# Ensure invalid model was not appended
assert test_attribute == orig_class_list


@pytest.mark.parametrize(
"project",
[
"r1_default_project",
"r1_monolayer",
"r1_monolayer_8_contrasts",
"r1_orso_polymer",
"r1_motofit_bench_mark",
"dspc_standard_layers",
"dspc_custom_layers",
"dspc_custom_xy",
"domains_standard_layers",
"domains_custom_layers",
"domains_custom_xy",
"absorption",
],
)
def test_save_load(project, request):
"""Test that saving and loading a project returns the same project."""
original_project = request.getfixturevalue(project)

with tempfile.TemporaryDirectory() as tmp:
original_project.save(tmp)
converted_project = RATapi.Project.load(Path(tmp, "project.json"))

for field in RATapi.Project.model_fields:
assert getattr(converted_project, field) == getattr(original_project, field)
Loading