diff --git a/docs/user-guide/plotting.qmd b/docs/user-guide/plotting.qmd index 5fdce145d..ca364d885 100644 --- a/docs/user-guide/plotting.qmd +++ b/docs/user-guide/plotting.qmd @@ -71,6 +71,13 @@ cmp.plot.timeseries(); cmp.plot.scatter(); ``` +## Transform values +```{python} +(cmp.transform_values(lambda x: x / 100.0, new_quantity=ms.Quantity(cmp.quantity.name, "cm")) + .plot.timeseries() +); +``` + ## Taylor diagrams diff --git a/src/modelskill/comparison/_collection.py b/src/modelskill/comparison/_collection.py index 2e3bddf1c..7eb1e2121 100644 --- a/src/modelskill/comparison/_collection.py +++ b/src/modelskill/comparison/_collection.py @@ -16,6 +16,7 @@ overload, Hashable, Tuple, + TYPE_CHECKING, ) import zipfile import numpy as np @@ -38,6 +39,10 @@ IdxOrNameTypes, TimeTypes, ) +from ..quantity import Quantity + +if TYPE_CHECKING: + from numpy._typing._array_like import NDArray class ComparerCollection(Mapping): @@ -875,3 +880,27 @@ def _load_comparer(folder: str, f: str) -> Comparer: cmp = Comparer.load(f) os.remove(f) return cmp + + def transform_values( + self, + func: Callable[[NDArray[np.floating]], NDArray[np.floating]], + new_quantity: Quantity | None = None, + ) -> "ComparerCollection": + """Transform the values of all comparers using a function. + + Parameters + ---------- + func : Callable + Function to apply to the values. + new_quantity : Quantity, optional + New quantity for the transformed values. If None, the original quantity is used. + + Returns + ------- + ComparerCollection + New ComparerCollection with transformed values. + """ + cmps = [ + cmp.transform_values(func=func, new_quantity=new_quantity) for cmp in self + ] + return ComparerCollection(cmps) diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index eff157a34..151188c44 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from ._collection import ComparerCollection + from numpy._typing._array_like import NDArray Serializable = Union[str, int, float] @@ -1273,3 +1274,35 @@ def load(filename: Union[str, Path]) -> "Comparer": else: raise NotImplementedError(f"Unknown gtype: {data.gtype}") + + def transform_values( + self, + func: Callable[[NDArray[np.floating]], NDArray[np.floating]], + new_quantity: Quantity | None = None, + ) -> Comparer: + """Transform observation and model values using a function + + Parameters + ---------- + func : Callable[NDArray[np.floating]], NDArray[np.floating]] + function to apply to observation and model values + new_quantity : Quantity, optional + new quantity, by default None + + Returns + ------- + Comparer + new Comparer with transformed values + + """ + cmp = self.copy() + for var in cmp.data.data_vars: + if cmp.data[var].attrs["kind"] in ["observation", "model"]: + cmp.data[var].values = func(cmp.data[var].values) + for var, ts in cmp.raw_mod_data.items(): + ts.data[var].values = func(ts.data[var].values) + cmp.raw_mod_data[var] = ts + + if new_quantity is not None: + cmp.quantity = new_quantity + return cmp diff --git a/tests/test_comparer.py b/tests/test_comparer.py index 5e27818a9..dc712f30c 100644 --- a/tests/test_comparer.py +++ b/tests/test_comparer.py @@ -990,3 +990,23 @@ def test_load_comparer_from_root_namespace(pt_df, tmp_path): assert cmp2.n_points == 6 assert cmp2.name == "Observation" assert cmp2.score()["m1"] == pytest.approx(0.5916079783099617) + + +def test_transform_values_new_quantity(pc: Comparer) -> None: + assert pc.data.m1.values[0] == 1.5 + assert pc.data.Observation.values[0] == 1.0 + pc2 = pc.transform_values(lambda x: x / 100.0, ms.Quantity(pc.quantity.name, "cm")) + assert pc2.quantity.unit == "cm" + assert pc.data.m1.values[0] == 1.5 + assert pc2.data.m1.values[0] == 0.015 + assert pc.data.Observation.values[0] == 1.0 + assert pc2.data.Observation.values[0] == 0.01 + + +def test_transform_values_keep_quantity(pc: Comparer) -> None: + pc2 = pc.transform_values(lambda x: x + 1.0) + assert pc2.quantity.unit == "m" + assert pc.data.m1.values[0] == 1.5 + assert pc2.data.m1.values[0] == 2.5 + assert pc.data.Observation.values[0] == 1.0 + assert pc2.data.Observation.values[0] == 2.0 diff --git a/tests/test_comparercollection.py b/tests/test_comparercollection.py index 85587979c..5b5806def 100644 --- a/tests/test_comparercollection.py +++ b/tests/test_comparercollection.py @@ -670,3 +670,10 @@ def test_score_changes_when_weights_override_defaults(): assert cc.score()["m"] == pytest.approx(1.90909) assert cc.score(weights={"bar": 2.0})["m"] == pytest.approx(1.8333333) assert cc.score(weights={"foo": 1.0, "bar": 2.0})["m"] == pytest.approx(1.333333) + + +def test_transform_values_new_quantity(cc: ms.ComparerCollection) -> None: + cc2 = cc.transform_values( + lambda x: x / 100.0, ms.Quantity(cc[0].quantity.name, "cm") + ) + assert all(cmp.quantity.unit == "cm" for cmp in cc2)