Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions RATapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,42 @@ def __init__(self, init_list: Union[Sequence[T], T] = None, name_field: str = "n
super().__init__(init_list)

def __str__(self):
# `display_fields` gives more control over the items displayed from the list if available
if not self.data:
return str([])
try:
[model.__dict__ for model in self.data]
model_display_fields = [model.display_fields for model in self.data]
# get all items included in at least one list
# the list comprehension ensures they are in the order that they're in in the model
required_fields = list(set().union(*model_display_fields))
table_fields = ["index"] + [i for i in list(self.data[0].__dict__) if i in required_fields]
except AttributeError:
output = str(self.data)
try:
model_display_fields = [model.__dict__ for model in self.data]
table_fields = ["index"] + list(self.data[0].__dict__)
except AttributeError:
return str(self.data)

if any(model_display_fields):
table = prettytable.PrettyTable()
table.field_names = [field.replace("_", " ") for field in table_fields]
rows = []
for index, model in enumerate(self.data):
row = [index]
for field in table_fields[1:]:
value = getattr(model, field, "")
if isinstance(value, np.ndarray):
value = (
f"{'Data array: [' + ' x '.join(str(i) for i in value.shape) if value.size > 0 else '['}]"
)
elif field == "model":
value = "\n".join(str(element) for element in value)
else:
value = str(value)
row.append(value)
rows.append(row)
table.add_rows(rows)
output = table.get_string()
else:
if any(model.__dict__ for model in self.data):
table = prettytable.PrettyTable()
Expand Down
31 changes: 29 additions & 2 deletions RATapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@ def __repr__(self):

def __str__(self):
table = prettytable.PrettyTable()
table.field_names = [key.replace("_", " ") for key in self.__dict__]
table.add_row(list(self.__dict__.values()))
table.field_names = [key.replace("_", " ") for key in self.display_fields]
table.add_row(list(self.display_fields.values()))
return table.get_string()

@property
def display_fields(self) -> dict:
"""A dictionary of which fields should be displayed by this model and their values."""
return self.__dict__


class Signal(RATModel):
"""Base model for background & resolution signals."""
Expand All @@ -67,6 +72,16 @@ def __setattr__(self, name, value):

super().__setattr__(name, value)

@property
def display_fields(self) -> dict:
visible_fields = ["name", "type", "source"]
if self.type != TypeOptions.Constant:
visible_fields.append("value_1")
if self.type == TypeOptions.Function:
visible_fields.extend(["value_2", "value_3", "value_4", "value_5"])

return {f: getattr(self, f) for f in visible_fields}


class Background(Signal):
"""A background signal.
Expand Down Expand Up @@ -525,6 +540,8 @@ class Parameter(RATModel):
mu: float = 0.0
sigma: float = np.inf

show_priors: bool = False

@model_validator(mode="after")
def check_min_max(self) -> "Parameter":
"""The maximum value of a parameter must be greater than the minimum."""
Expand All @@ -539,6 +556,16 @@ def check_value_in_range(self) -> "Parameter":
raise ValueError(f"value {self.value} is not within the defined range: {self.min} <= value <= {self.max}")
return self

@property
def display_fields(self) -> dict:
visible_fields = ["name", "min", "value", "max", "fit"]
if self.show_priors:
visible_fields.append("prior_type")
if self.prior_type == Priors.Gaussian:
visible_fields.extend(["mu", "sigma"])

return {f: getattr(self, f) for f in visible_fields}


class ProtectedParameter(Parameter):
"""A Parameter with a fixed name."""
Expand Down
24 changes: 24 additions & 0 deletions RATapi/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,30 @@ def get_contrast_model_field(self):
model_field = "custom_files"
return model_field

def set_prior_visibility(self, priors_visible: bool):
"""Set whether priors are visible or invisible for all parameters.

Parameters
----------
priors_visible : bool
Whether priors should be shown.

"""
for classlist_name in parameter_class_lists:
classlist = getattr(self, classlist_name)
for i in range(0, len(classlist)):
classlist[i].show_priors = priors_visible

def show_priors(self):
"""Show priors for all parameters in the model."""
# convenience function from set_prior_visibility
self.set_prior_visibility(True)

def hide_priors(self):
"""Hide priors for all parameters in the model."""
# convenience function from set_prior_visibility
self.set_prior_visibility(False)

def write_script(self, obj_name: str = "problem", script: str = "project_script.py"):
"""Write a python script that can be run to reproduce this project object.

Expand Down
38 changes: 38 additions & 0 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Iterable, Sequence
from typing import Any, Union

import prettytable
import pytest

from RATapi.classlist import ClassList
Expand Down Expand Up @@ -44,6 +45,23 @@ def three_name_class_list():
return ClassList([InputAttributes(name="Alice"), InputAttributes(name="Bob"), InputAttributes(name="Eve")])


class DisplayFieldsClass:
"""A classlist with four attributes and a display_fields property."""

def __init__(self, display_range):
self.a = 1
self.b = 2
self.c = 3
self.d = 4

self.display_range = display_range

@property
def display_fields(self):
fields = ["a", "b", "c", "d"][self.display_range[0] : self.display_range[1]]
return {f: getattr(self, f) for f in fields}


class TestInitialisation:
@pytest.mark.parametrize(
"input_object",
Expand Down Expand Up @@ -174,6 +192,26 @@ def test_str_empty_classlist() -> None:
assert str(ClassList()) == str([])


@pytest.mark.parametrize(
"display_ranges, expected_header",
(
([(1, 3), (1, 3), (1, 3)], ["b", "c"]),
([(1, 2), (0, 4), (2, 3)], ["a", "b", "c", "d"]),
([(0, 2), (0, 1), (2, 3)], ["a", "b", "c"]),
),
)
def test_str_display_fields(display_ranges, expected_header):
"""If a class has the `display_fields` property, the ClassList should print with the minimal required attributes."""
class_list = ClassList([DisplayFieldsClass(dr) for dr in display_ranges])
expected_table = prettytable.PrettyTable()
expected_table.field_names = ["index"] + expected_header
expected_vals = {"a": 1, "b": 2, "c": 3, "d": 4}
row = [expected_vals[v] for v in expected_header]
expected_table.add_rows([[0] + row, [1] + row, [2] + row])

assert str(class_list) == expected_table.get_string()


@pytest.mark.parametrize(
["new_item", "expected_classlist"],
[
Expand Down
80 changes: 40 additions & 40 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,53 +80,53 @@ def default_project_str():
"Geometry: ------------------------------------------------------------------------------------------\n\n"
"air/substrate\n\n"
"Parameters: ----------------------------------------------------------------------------------------\n\n"
"+-------+---------------------+-----+-------+-----+------+------------+-----+-------+\n"
"| index | name | min | value | max | fit | prior type | mu | sigma |\n"
"+-------+---------------------+-----+-------+-----+------+------------+-----+-------+\n"
"| 0 | Substrate Roughness | 1.0 | 3.0 | 5.0 | True | uniform | 0.0 | inf |\n"
"+-------+---------------------+-----+-------+-----+------+------------+-----+-------+\n\n"
"+-------+---------------------+-----+-------+-----+------+\n"
"| index | name | min | value | max | fit |\n"
"+-------+---------------------+-----+-------+-----+------+\n"
"| 0 | Substrate Roughness | 1.0 | 3.0 | 5.0 | True |\n"
"+-------+---------------------+-----+-------+-----+------+\n\n"
"Bulk In: -------------------------------------------------------------------------------------------\n\n"
"+-------+---------+-----+-------+-----+-------+------------+-----+-------+\n"
"| index | name | min | value | max | fit | prior type | mu | sigma |\n"
"+-------+---------+-----+-------+-----+-------+------------+-----+-------+\n"
"| 0 | SLD Air | 0.0 | 0.0 | 0.0 | False | uniform | 0.0 | inf |\n"
"+-------+---------+-----+-------+-----+-------+------------+-----+-------+\n\n"
"+-------+---------+-----+-------+-----+-------+\n"
"| index | name | min | value | max | fit |\n"
"+-------+---------+-----+-------+-----+-------+\n"
"| 0 | SLD Air | 0.0 | 0.0 | 0.0 | False |\n"
"+-------+---------+-----+-------+-----+-------+\n\n"
"Bulk Out: ------------------------------------------------------------------------------------------\n\n"
"+-------+---------+---------+----------+----------+-------+------------+-----+-------+\n"
"| index | name | min | value | max | fit | prior type | mu | sigma |\n"
"+-------+---------+---------+----------+----------+-------+------------+-----+-------+\n"
"| 0 | SLD D2O | 6.2e-06 | 6.35e-06 | 6.35e-06 | False | uniform | 0.0 | inf |\n"
"+-------+---------+---------+----------+----------+-------+------------+-----+-------+\n\n"
"+-------+---------+---------+----------+----------+-------+\n"
"| index | name | min | value | max | fit |\n"
"+-------+---------+---------+----------+----------+-------+\n"
"| 0 | SLD D2O | 6.2e-06 | 6.35e-06 | 6.35e-06 | False |\n"
"+-------+---------+---------+----------+----------+-------+\n\n"
"Scalefactors: --------------------------------------------------------------------------------------\n\n"
"+-------+---------------+------+-------+------+-------+------------+-----+-------+\n"
"| index | name | min | value | max | fit | prior type | mu | sigma |\n"
"+-------+---------------+------+-------+------+-------+------------+-----+-------+\n"
"| 0 | Scalefactor 1 | 0.02 | 0.23 | 0.25 | False | uniform | 0.0 | inf |\n"
"+-------+---------------+------+-------+------+-------+------------+-----+-------+\n\n"
"+-------+---------------+------+-------+------+-------+\n"
"| index | name | min | value | max | fit |\n"
"+-------+---------------+------+-------+------+-------+\n"
"| 0 | Scalefactor 1 | 0.02 | 0.23 | 0.25 | False |\n"
"+-------+---------------+------+-------+------+-------+\n\n"
"Background Parameters: -----------------------------------------------------------------------------\n\n"
"+-------+--------------------+-------+-------+-------+-------+------------+-----+-------+\n"
"| index | name | min | value | max | fit | prior type | mu | sigma |\n"
"+-------+--------------------+-------+-------+-------+-------+------------+-----+-------+\n"
"| 0 | Background Param 1 | 1e-07 | 1e-06 | 1e-05 | False | uniform | 0.0 | inf |\n"
"+-------+--------------------+-------+-------+-------+-------+------------+-----+-------+\n\n"
"+-------+--------------------+-------+-------+-------+-------+\n"
"| index | name | min | value | max | fit |\n"
"+-------+--------------------+-------+-------+-------+-------+\n"
"| 0 | Background Param 1 | 1e-07 | 1e-06 | 1e-05 | False |\n"
"+-------+--------------------+-------+-------+-------+-------+\n\n"
"Backgrounds: ---------------------------------------------------------------------------------------\n\n"
"+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n"
"| index | name | type | source | value 1 | value 2 | value 3 | value 4 | value 5 |\n"
"+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n"
"| 0 | Background 1 | constant | Background Param 1 | | | | | |\n"
"+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n\n"
"+-------+--------------+----------+--------------------+\n"
"| index | name | type | source |\n"
"+-------+--------------+----------+--------------------+\n"
"| 0 | Background 1 | constant | Background Param 1 |\n"
"+-------+--------------+----------+--------------------+\n\n"
"Resolution Parameters: -----------------------------------------------------------------------------\n\n"
"+-------+--------------------+------+-------+------+-------+------------+-----+-------+\n"
"| index | name | min | value | max | fit | prior type | mu | sigma |\n"
"+-------+--------------------+------+-------+------+-------+------------+-----+-------+\n"
"| 0 | Resolution Param 1 | 0.01 | 0.03 | 0.05 | False | uniform | 0.0 | inf |\n"
"+-------+--------------------+------+-------+------+-------+------------+-----+-------+\n\n"
"+-------+--------------------+------+-------+------+-------+\n"
"| index | name | min | value | max | fit |\n"
"+-------+--------------------+------+-------+------+-------+\n"
"| 0 | Resolution Param 1 | 0.01 | 0.03 | 0.05 | False |\n"
"+-------+--------------------+------+-------+------+-------+\n\n"
"Resolutions: ---------------------------------------------------------------------------------------\n\n"
"+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n"
"| index | name | type | source | value 1 | value 2 | value 3 | value 4 | value 5 |\n"
"+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n"
"| 0 | Resolution 1 | constant | Resolution Param 1 | | | | | |\n"
"+-------+--------------+----------+--------------------+---------+---------+---------+---------+---------+\n\n"
"+-------+--------------+----------+--------------------+\n"
"| index | name | type | source |\n"
"+-------+--------------+----------+--------------------+\n"
"| 0 | Resolution 1 | constant | Resolution Param 1 |\n"
"+-------+--------------+----------+--------------------+\n\n"
"Data: ----------------------------------------------------------------------------------------------\n\n"
"+-------+------------+------+------------+------------------+\n"
"| index | name | data | data range | simulation range |\n"
Expand Down
Loading