Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/user-guide/plotting.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ cmp.plot.timeseries();
cmp.plot.scatter();
```

## Transform values
```{python}
(cmp.transform_values(lambda x: x / 100.0, new_quantity=ms.Quantity(cmp.quantity.name, "cm"))
.plot.timeseries()
);
```


## Taylor diagrams

Expand Down
29 changes: 29 additions & 0 deletions src/modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
overload,
Hashable,
Tuple,
TYPE_CHECKING,
)
import zipfile
import numpy as np
Expand All @@ -38,6 +39,10 @@
IdxOrNameTypes,
TimeTypes,
)
from ..quantity import Quantity

if TYPE_CHECKING:
from numpy._typing._array_like import NDArray


class ComparerCollection(Mapping):
Expand Down Expand Up @@ -875,3 +880,27 @@ def _load_comparer(folder: str, f: str) -> Comparer:
cmp = Comparer.load(f)
os.remove(f)
return cmp

def transform_values(
self,
func: Callable[[NDArray[np.floating]], NDArray[np.floating]],
new_quantity: Quantity | None = None,
) -> "ComparerCollection":
"""Transform the values of all comparers using a function.

Parameters
----------
func : Callable
Function to apply to the values.
new_quantity : Quantity, optional
New quantity for the transformed values. If None, the original quantity is used.

Returns
-------
ComparerCollection
New ComparerCollection with transformed values.
"""
cmps = [
cmp.transform_values(func=func, new_quantity=new_quantity) for cmp in self
]
return ComparerCollection(cmps)
33 changes: 33 additions & 0 deletions src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

if TYPE_CHECKING:
from ._collection import ComparerCollection
from numpy._typing._array_like import NDArray

Serializable = Union[str, int, float]

Expand Down Expand Up @@ -1273,3 +1274,35 @@ def load(filename: Union[str, Path]) -> "Comparer":

else:
raise NotImplementedError(f"Unknown gtype: {data.gtype}")

def transform_values(
self,
func: Callable[[NDArray[np.floating]], NDArray[np.floating]],
new_quantity: Quantity | None = None,
) -> Comparer:
"""Transform observation and model values using a function

Parameters
----------
func : Callable[NDArray[np.floating]], NDArray[np.floating]]
function to apply to observation and model values
new_quantity : Quantity, optional
new quantity, by default None

Returns
-------
Comparer
new Comparer with transformed values

"""
cmp = self.copy()
for var in cmp.data.data_vars:
if cmp.data[var].attrs["kind"] in ["observation", "model"]:
cmp.data[var].values = func(cmp.data[var].values)
for var, ts in cmp.raw_mod_data.items():
ts.data[var].values = func(ts.data[var].values)
cmp.raw_mod_data[var] = ts

if new_quantity is not None:
cmp.quantity = new_quantity
return cmp
20 changes: 20 additions & 0 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,3 +990,23 @@ def test_load_comparer_from_root_namespace(pt_df, tmp_path):
assert cmp2.n_points == 6
assert cmp2.name == "Observation"
assert cmp2.score()["m1"] == pytest.approx(0.5916079783099617)


def test_transform_values_new_quantity(pc: Comparer) -> None:
assert pc.data.m1.values[0] == 1.5
assert pc.data.Observation.values[0] == 1.0
pc2 = pc.transform_values(lambda x: x / 100.0, ms.Quantity(pc.quantity.name, "cm"))
assert pc2.quantity.unit == "cm"
assert pc.data.m1.values[0] == 1.5
assert pc2.data.m1.values[0] == 0.015
assert pc.data.Observation.values[0] == 1.0
assert pc2.data.Observation.values[0] == 0.01


def test_transform_values_keep_quantity(pc: Comparer) -> None:
pc2 = pc.transform_values(lambda x: x + 1.0)
assert pc2.quantity.unit == "m"
assert pc.data.m1.values[0] == 1.5
assert pc2.data.m1.values[0] == 2.5
assert pc.data.Observation.values[0] == 1.0
assert pc2.data.Observation.values[0] == 2.0
7 changes: 7 additions & 0 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,3 +670,10 @@ def test_score_changes_when_weights_override_defaults():
assert cc.score()["m"] == pytest.approx(1.90909)
assert cc.score(weights={"bar": 2.0})["m"] == pytest.approx(1.8333333)
assert cc.score(weights={"foo": 1.0, "bar": 2.0})["m"] == pytest.approx(1.333333)


def test_transform_values_new_quantity(cc: ms.ComparerCollection) -> None:
cc2 = cc.transform_values(
lambda x: x / 100.0, ms.Quantity(cc[0].quantity.name, "cm")
)
assert all(cmp.quantity.unit == "cm" for cmp in cc2)