From 66acc5dc95e1e1e3a5fae72bacd2ae977a80945b Mon Sep 17 00:00:00 2001 From: alexhroom Date: Thu, 19 Dec 2024 11:52:34 +0000 Subject: [PATCH 1/3] makes model and classlist printing more flexible via a display_fields property --- RATapi/classlist.py | 34 +++++++++++++++++++++-- RATapi/models.py | 21 +++++++++++++-- tests/test_classlist.py | 38 ++++++++++++++++++++++++++ tests/test_project.py | 60 ++++++++++++++++++++--------------------- 4 files changed, 119 insertions(+), 34 deletions(-) diff --git a/RATapi/classlist.py b/RATapi/classlist.py index 68b0ea3c..10c4975f 100644 --- a/RATapi/classlist.py +++ b/RATapi/classlist.py @@ -57,10 +57,40 @@ def __init__(self, init_list: Union[Sequence[T], T] = None, name_field: str = "n super().__init__(init_list) def __str__(self): + # `display_fields` gives more control over the items displayed from the list if available + if not self.data: + return str([]) try: - [model.__dict__ for model in self.data] + model_display_fields = [model.display_fields for model in self.data] + # get all items included in at least one list + # the list comprehension ensures they are in the order that they're in in the model + required_fields = list(set().union(*model_display_fields)) + table_fields = ["index"] + [i for i in list(self.data[0].__dict__) if i in required_fields] except AttributeError: - output = str(self.data) + try: + model_display_fields = [model.__dict__ for model in self.data] + table_fields = ["index"] + list(self.data[0].__dict__) + except AttributeError: + return str(self.data) + + if any(model_display_fields): + table = prettytable.PrettyTable() + table.field_names = [field.replace("_", " ") for field in table_fields] + rows = [] + for index, model in enumerate(self.data): + row = [index] + for field in table_fields[1:]: + value = getattr(model, field, "") + if isinstance(value, np.ndarray): + value = f"{'Data array: ['+' x '.join(str(i) for i in value.shape) if value.size > 0 else '['}]" + elif field == "model": + value = "\n".join(str(element) for element in value) + else: + value = str(value) + row.append(value) + rows.append(row) + table.add_rows(rows) + output = table.get_string() else: if any(model.__dict__ for model in self.data): table = prettytable.PrettyTable() diff --git a/RATapi/models.py b/RATapi/models.py index 1274a2a8..b0ecbf73 100644 --- a/RATapi/models.py +++ b/RATapi/models.py @@ -40,10 +40,15 @@ def __repr__(self): def __str__(self): table = prettytable.PrettyTable() - table.field_names = [key.replace("_", " ") for key in self.__dict__] - table.add_row(list(self.__dict__.values())) + table.field_names = [key.replace("_", " ") for key in self.display_fields] + table.add_row(list(self.display_fields.values())) return table.get_string() + @property + def display_fields(self) -> dict: + """A dictionary of which fields should be displayed by this model and their values.""" + return self.__dict__ + class Signal(RATModel): """Base model for background & resolution signals.""" @@ -525,6 +530,8 @@ class Parameter(RATModel): mu: float = 0.0 sigma: float = np.inf + show_priors: bool = False + @model_validator(mode="after") def check_min_max(self) -> "Parameter": """The maximum value of a parameter must be greater than the minimum.""" @@ -539,6 +546,16 @@ def check_value_in_range(self) -> "Parameter": raise ValueError(f"value {self.value} is not within the defined range: {self.min} <= value <= {self.max}") return self + @property + def display_fields(self) -> dict: + visible_fields = ["name", "min", "value", "max", "fit"] + if self.show_priors: + visible_fields.append("prior_type") + if self.prior_type == Priors.Gaussian: + visible_fields.extend(["mu", "sigma"]) + + return {f: getattr(self, f) for f in visible_fields} + class ProtectedParameter(Parameter): """A Parameter with a fixed name.""" diff --git a/tests/test_classlist.py b/tests/test_classlist.py index 71816a5e..2a9aa881 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -7,6 +7,7 @@ from collections.abc import Iterable, Sequence from typing import Any, Union +import prettytable import pytest from RATapi.classlist import ClassList @@ -44,6 +45,23 @@ def three_name_class_list(): return ClassList([InputAttributes(name="Alice"), InputAttributes(name="Bob"), InputAttributes(name="Eve")]) +class DisplayFieldsClass: + """A classlist with four attributes and a display_fields property.""" + + def __init__(self, display_range): + self.a = 1 + self.b = 2 + self.c = 3 + self.d = 4 + + self.display_range = display_range + + @property + def display_fields(self): + fields = ["a", "b", "c", "d"][self.display_range[0] : self.display_range[1]] + return {f: getattr(self, f) for f in fields} + + class TestInitialisation: @pytest.mark.parametrize( "input_object", @@ -174,6 +192,26 @@ def test_str_empty_classlist() -> None: assert str(ClassList()) == str([]) +@pytest.mark.parametrize( + "display_ranges, expected_header", + ( + ([(1, 3), (1, 3), (1, 3)], ["b", "c"]), + ([(1, 2), (0, 4), (2, 3)], ["a", "b", "c", "d"]), + ([(0, 2), (0, 1), (2, 3)], ["a", "b", "c"]), + ), +) +def test_str_display_fields(display_ranges, expected_header): + """If a class has the `display_fields` property, the ClassList should print with the minimal required attributes.""" + class_list = ClassList([DisplayFieldsClass(dr) for dr in display_ranges]) + expected_table = prettytable.PrettyTable() + expected_table.field_names = ["index"] + expected_header + expected_vals = {"a": 1, "b": 2, "c": 3, "d": 4} + row = [expected_vals[v] for v in expected_header] + expected_table.add_rows([[0] + row, [1] + row, [2] + row]) + + assert str(class_list) == expected_table.get_string() + + @pytest.mark.parametrize( ["new_item", "expected_classlist"], [ diff --git a/tests/test_project.py b/tests/test_project.py index 5cc0cae8..8141d6e5 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -80,35 +80,35 @@ def default_project_str(): "Geometry: ------------------------------------------------------------------------------------------\n\n" "air/substrate\n\n" "Parameters: ----------------------------------------------------------------------------------------\n\n" - "+-------+---------------------+-----+-------+-----+------+------------+-----+-------+\n" - "| index | name | min | value | max | fit | prior type | mu | sigma |\n" - "+-------+---------------------+-----+-------+-----+------+------------+-----+-------+\n" - "| 0 | Substrate Roughness | 1.0 | 3.0 | 5.0 | True | uniform | 0.0 | inf |\n" - "+-------+---------------------+-----+-------+-----+------+------------+-----+-------+\n\n" + "+-------+---------------------+-----+-------+-----+------+\n" + "| index | name | min | value | max | fit |\n" + "+-------+---------------------+-----+-------+-----+------+\n" + "| 0 | Substrate Roughness | 1.0 | 3.0 | 5.0 | True |\n" + "+-------+---------------------+-----+-------+-----+------+\n\n" "Bulk In: -------------------------------------------------------------------------------------------\n\n" - "+-------+---------+-----+-------+-----+-------+------------+-----+-------+\n" - "| index | name | min | value | max | fit | prior type | mu | sigma |\n" - "+-------+---------+-----+-------+-----+-------+------------+-----+-------+\n" - "| 0 | SLD Air | 0.0 | 0.0 | 0.0 | False | uniform | 0.0 | inf |\n" - "+-------+---------+-----+-------+-----+-------+------------+-----+-------+\n\n" + "+-------+---------+-----+-------+-----+-------+\n" + "| index | name | min | value | max | fit |\n" + "+-------+---------+-----+-------+-----+-------+\n" + "| 0 | SLD Air | 0.0 | 0.0 | 0.0 | False |\n" + "+-------+---------+-----+-------+-----+-------+\n\n" "Bulk Out: ------------------------------------------------------------------------------------------\n\n" - "+-------+---------+---------+----------+----------+-------+------------+-----+-------+\n" - "| index | name | min | value | max | fit | prior type | mu | sigma |\n" - "+-------+---------+---------+----------+----------+-------+------------+-----+-------+\n" - "| 0 | SLD D2O | 6.2e-06 | 6.35e-06 | 6.35e-06 | False | uniform | 0.0 | inf |\n" - "+-------+---------+---------+----------+----------+-------+------------+-----+-------+\n\n" + "+-------+---------+---------+----------+----------+-------+\n" + "| index | name | min | value | max | fit |\n" + "+-------+---------+---------+----------+----------+-------+\n" + "| 0 | SLD D2O | 6.2e-06 | 6.35e-06 | 6.35e-06 | False |\n" + "+-------+---------+---------+----------+----------+-------+\n\n" "Scalefactors: --------------------------------------------------------------------------------------\n\n" - "+-------+---------------+------+-------+------+-------+------------+-----+-------+\n" - "| index | name | min | value | max | fit | prior type | mu | sigma |\n" - "+-------+---------------+------+-------+------+-------+------------+-----+-------+\n" - "| 0 | Scalefactor 1 | 0.02 | 0.23 | 0.25 | False | uniform | 0.0 | inf |\n" - "+-------+---------------+------+-------+------+-------+------------+-----+-------+\n\n" + "+-------+---------------+------+-------+------+-------+\n" + "| index | name | min | value | max | fit |\n" + "+-------+---------------+------+-------+------+-------+\n" + "| 0 | Scalefactor 1 | 0.02 | 0.23 | 0.25 | False |\n" + "+-------+---------------+------+-------+------+-------+\n\n" "Background Parameters: -----------------------------------------------------------------------------\n\n" - "+-------+--------------------+-------+-------+-------+-------+------------+-----+-------+\n" - "| index | name | min | value | max | fit | prior type | mu | sigma |\n" - "+-------+--------------------+-------+-------+-------+-------+------------+-----+-------+\n" - "| 0 | Background Param 1 | 1e-07 | 1e-06 | 1e-05 | False | uniform | 0.0 | inf |\n" - "+-------+--------------------+-------+-------+-------+-------+------------+-----+-------+\n\n" + "+-------+--------------------+-------+-------+-------+-------+\n" + "| index | name | min | value | max | fit |\n" + "+-------+--------------------+-------+-------+-------+-------+\n" + "| 0 | Background Param 1 | 1e-07 | 1e-06 | 1e-05 | False |\n" + "+-------+--------------------+-------+-------+-------+-------+\n\n" "Backgrounds: ---------------------------------------------------------------------------------------\n\n" "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n" "| index | name | type | source | value 1 | value 2 | value 3 | value 4 | value 5 |\n" @@ -116,11 +116,11 @@ def default_project_str(): "| 0 | Background 1 | constant | Background Param 1 | | | | | |\n" "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n\n" "Resolution Parameters: -----------------------------------------------------------------------------\n\n" - "+-------+--------------------+------+-------+------+-------+------------+-----+-------+\n" - "| index | name | min | value | max | fit | prior type | mu | sigma |\n" - "+-------+--------------------+------+-------+------+-------+------------+-----+-------+\n" - "| 0 | Resolution Param 1 | 0.01 | 0.03 | 0.05 | False | uniform | 0.0 | inf |\n" - "+-------+--------------------+------+-------+------+-------+------------+-----+-------+\n\n" + "+-------+--------------------+------+-------+------+-------+\n" + "| index | name | min | value | max | fit |\n" + "+-------+--------------------+------+-------+------+-------+\n" + "| 0 | Resolution Param 1 | 0.01 | 0.03 | 0.05 | False |\n" + "+-------+--------------------+------+-------+------+-------+\n\n" "Resolutions: ---------------------------------------------------------------------------------------\n\n" "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n" "| index | name | type | source | value 1 | value 2 | value 3 | value 4 | value 5 |\n" From 7d2b81be7553e72b3f9127576d30cd117a4d2537 Mon Sep 17 00:00:00 2001 From: alexhroom Date: Tue, 7 Jan 2025 11:14:25 +0000 Subject: [PATCH 2/3] added prior visibility setting for projects --- RATapi/project.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/RATapi/project.py b/RATapi/project.py index eec7ec99..b8c51e5e 100644 --- a/RATapi/project.py +++ b/RATapi/project.py @@ -713,6 +713,30 @@ def get_contrast_model_field(self): model_field = "custom_files" return model_field + def set_prior_visibility(self, priors_visible: bool): + """Set whether priors are visible or invisible for all parameters. + + Parameters + ---------- + priors_visible : bool + Whether priors should be shown. + + """ + for classlist_name in parameter_class_lists: + classlist = getattr(self, classlist_name) + for i in range(0, len(classlist)): + classlist[i].show_priors = priors_visible + + def show_priors(self): + """Show priors for all parameters in the model.""" + # convenience function from set_prior_visibility + self.set_prior_visibility(True) + + def hide_priors(self): + """Hide priors for all parameters in the model.""" + # convenience function from set_prior_visibility + self.set_prior_visibility(False) + def write_script(self, obj_name: str = "problem", script: str = "project_script.py"): """Write a python script that can be run to reproduce this project object. From e1306a5e223996cafbe201b478a6e1ae54daddfd Mon Sep 17 00:00:00 2001 From: alexhroom Date: Fri, 10 Jan 2025 14:01:39 +0000 Subject: [PATCH 3/3] added unused background field hiding --- RATapi/classlist.py | 4 +++- RATapi/models.py | 10 ++++++++++ tests/test_project.py | 20 ++++++++++---------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/RATapi/classlist.py b/RATapi/classlist.py index 10c4975f..48a3e6fc 100644 --- a/RATapi/classlist.py +++ b/RATapi/classlist.py @@ -82,7 +82,9 @@ def __str__(self): for field in table_fields[1:]: value = getattr(model, field, "") if isinstance(value, np.ndarray): - value = f"{'Data array: ['+' x '.join(str(i) for i in value.shape) if value.size > 0 else '['}]" + value = ( + f"{'Data array: [' + ' x '.join(str(i) for i in value.shape) if value.size > 0 else '['}]" + ) elif field == "model": value = "\n".join(str(element) for element in value) else: diff --git a/RATapi/models.py b/RATapi/models.py index b0ecbf73..095badbe 100644 --- a/RATapi/models.py +++ b/RATapi/models.py @@ -72,6 +72,16 @@ def __setattr__(self, name, value): super().__setattr__(name, value) + @property + def display_fields(self) -> dict: + visible_fields = ["name", "type", "source"] + if self.type != TypeOptions.Constant: + visible_fields.append("value_1") + if self.type == TypeOptions.Function: + visible_fields.extend(["value_2", "value_3", "value_4", "value_5"]) + + return {f: getattr(self, f) for f in visible_fields} + class Background(Signal): """A background signal. diff --git a/tests/test_project.py b/tests/test_project.py index 8141d6e5..6ec58075 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -110,11 +110,11 @@ def default_project_str(): "| 0 | Background Param 1 | 1e-07 | 1e-06 | 1e-05 | False |\n" "+-------+--------------------+-------+-------+-------+-------+\n\n" "Backgrounds: ---------------------------------------------------------------------------------------\n\n" - "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n" - "| index | name | type | source | value 1 | value 2 | value 3 | value 4 | value 5 |\n" - "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n" - "| 0 | Background 1 | constant | Background Param 1 | | | | | |\n" - "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n\n" + "+-------+--------------+----------+--------------------+\n" + "| index | name | type | source |\n" + "+-------+--------------+----------+--------------------+\n" + "| 0 | Background 1 | constant | Background Param 1 |\n" + "+-------+--------------+----------+--------------------+\n\n" "Resolution Parameters: -----------------------------------------------------------------------------\n\n" "+-------+--------------------+------+-------+------+-------+\n" "| index | name | min | value | max | fit |\n" @@ -122,11 +122,11 @@ def default_project_str(): "| 0 | Resolution Param 1 | 0.01 | 0.03 | 0.05 | False |\n" "+-------+--------------------+------+-------+------+-------+\n\n" "Resolutions: ---------------------------------------------------------------------------------------\n\n" - "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n" - "| index | name | type | source | value 1 | value 2 | value 3 | value 4 | value 5 |\n" - "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n" - "| 0 | Resolution 1 | constant | Resolution Param 1 | | | | | |\n" - "+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n\n" + "+-------+--------------+----------+--------------------+\n" + "| index | name | type | source |\n" + "+-------+--------------+----------+--------------------+\n" + "| 0 | Resolution 1 | constant | Resolution Param 1 |\n" + "+-------+--------------+----------+--------------------+\n\n" "Data: ----------------------------------------------------------------------------------------------\n\n" "+-------+------------+------+------------+------------------+\n" "| index | name | data | data range | simulation range |\n"