Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ dmypy.json
# Pyre type checker
.pyre/

# ignore data except for .msk files
data/*
!data/*.msk
tests/testdata/tmp
/tmp/

Expand All @@ -148,6 +146,7 @@ streamlit_app.py
notes.md

notebooks/Untitled.ipynb
.envrc

docs/_site/
docs/_extensions/
Expand Down
4 changes: 2 additions & 2 deletions docs/user-guide/workflow.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ This method allows filtering of the data in several ways:
It can be useful to save the comparer collection for later use. This can be done using the `save()` method:

```python
cc.save("my_comparer_collection.msk")
cc.save("my_comparer_collection.nc")
```

The comparer collection can be loaded again from disk, using the `load()` method:

```python
cc = ms.load("my_comparer_collection.msk")
cc = ms.load("my_comparer_collection.nc")
```


Expand Down
14,153 changes: 12,622 additions & 1,531 deletions notebooks/Metocean_MIKE21SW_DutchCoast.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies = [
"pandas >= 1.4",
"mikeio >= 1.2",
"matplotlib",
"xarray",
"xarray>=2024.10.0",
"netCDF4",
"scipy",
"jinja2", # used for skill.style
Expand Down
4 changes: 2 additions & 2 deletions src/modelskill/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def load(filename: Union[str, Path]) -> Comparer | ComparerCollection:
Examples
--------
>>> cc = ms.match(obs, mod)
>>> cc.save("my_comparer_collection.msk")
>>> cc2 = ms.load("my_comparer_collection.msk")"""
>>> cc.save("my_comparer_collection.nc")
>>> cc2 = ms.load("my_comparer_collection.nc")"""

try:
return ComparerCollection.load(filename)
Expand Down
70 changes: 37 additions & 33 deletions src/modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations
from copy import deepcopy
import os
from pathlib import Path
import tempfile
from typing import (
Any,
Callable,
Expand All @@ -18,9 +16,9 @@
Tuple,
)
import warnings
import zipfile
import numpy as np
import pandas as pd
import xarray as xr


from .. import metrics as mtr
Expand Down Expand Up @@ -829,42 +827,35 @@ def score(
return score

def save(self, filename: Union[str, Path]) -> None:
"""Save the ComparerCollection to a zip file.
"""Save the ComparerCollection to a hierarchical NetCDF file.

Each comparer is stored as a netcdf file in the zip file.
Each comparer is stored as a netcdf group.

Parameters
----------
filename : str or Path
Filename of the zip file.
Filename of the nc file.

Examples
--------
>>> cc = ms.match(obs, mod)
>>> cc.save("my_comparer_collection.msk")
>>> cc.save("my_comparer_collection.nc")
"""

files = []
no = 0
dt = xr.DataTree()
for name, cmp in self._comparers.items():
cmp_fn = f"{no}_{name}.nc"
cmp.save(cmp_fn)
files.append(cmp_fn)
no += 1
dtc = cmp._save()
dt[name] = dtc

with zipfile.ZipFile(filename, "w") as zip:
for f in files:
zip.write(f)
os.remove(f)
dt.to_netcdf(filename)

@staticmethod
def load(filename: Union[str, Path]) -> "ComparerCollection":
"""Load a ComparerCollection from a zip file.
def load(filename: Union[str, Path], method: str = "tree") -> "ComparerCollection":
"""Load a ComparerCollection from a NetCDF file.

Parameters
----------
filename : str or Path
Filename of the zip file.
Filename of the nc file.

Returns
-------
Expand All @@ -874,25 +865,38 @@ def load(filename: Union[str, Path]) -> "ComparerCollection":
Examples
--------
>>> cc = ms.match(obs, mod)
>>> cc.save("my_comparer_collection.msk")
>>> cc2 = ms.ComparerCollection.load("my_comparer_collection.msk")
>>> cc.save("my_comparer_collection.nc")
>>> cc2 = ms.ComparerCollection.load("my_comparer_collection.nc")
"""

folder = tempfile.TemporaryDirectory().name
if method == "tree":
dt = xr.open_datatree(filename)
groups = [x for x in dt.children]
comparers = [Comparer._load(dt[group]) for group in groups]

return ComparerCollection(comparers)
else:
import tempfile
import os
import zipfile

with zipfile.ZipFile(filename, "r") as zip:
for f in zip.namelist():
if f.endswith(".nc"):
zip.extract(f, path=folder)
folder = tempfile.TemporaryDirectory().name

comparers = [
ComparerCollection._load_comparer(folder, f)
for f in sorted(os.listdir(folder))
]
return ComparerCollection(comparers)
with zipfile.ZipFile(filename, "r") as zip:
for f in zip.namelist():
if f.endswith(".nc"):
zip.extract(f, path=folder)

comparers = [
ComparerCollection._load_comparer(folder, f)
for f in sorted(os.listdir(folder))
]
return ComparerCollection(comparers)

@staticmethod
def _load_comparer(folder: str, f: str) -> Comparer:
import os

f = os.path.join(folder, f)
cmp = Comparer.load(f)
os.remove(f)
Expand Down
70 changes: 53 additions & 17 deletions src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,29 @@ def to_dataframe(self) -> pd.DataFrame:
else:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def _save(self) -> xr.DataTree:
    """Assemble the DataTree representation of this comparer.

    The returned tree always holds the matched dataset under the
    ``"matched"`` node and records the geometry type in the root
    attribute ``"gtype"``. Point comparers additionally get a ``"raw"``
    node with one child per entry of ``self.raw_mod_data``.

    Returns
    -------
    xr.DataTree
        Tree ready to be written with ``to_netcdf``.

    Raises
    ------
    NotImplementedError
        If ``self.gtype`` is neither ``"point"`` nor ``"track"``.
    """
    matched = self.data
    gtype = self.gtype

    # Reject unsupported geometry types before building anything.
    if gtype != "point" and gtype != "track":
        raise NotImplementedError(f"Unknown gtype: {self.gtype}")

    tree = xr.DataTree()
    tree["matched"] = matched

    if gtype == "track":
        # There is no need to save raw data for track data, since it is
        # identical to the matched data
        tree.attrs["gtype"] = "track"
        return tree

    # Point data: keep every raw model time series in its own child node.
    tree["raw"] = xr.DataTree()
    for model_name, series in self.raw_mod_data.items():
        tree["raw"][model_name] = series.copy().data

    tree.attrs["gtype"] = "point"
    return tree

def save(self, filename: Union[str, Path]) -> None:
"""Save to netcdf file

Expand All @@ -1219,24 +1242,31 @@ def save(self, filename: Union[str, Path]) -> None:
filename : str or Path
filename
"""
ds = self.data
dt = self._save()

# add self.raw_mod_data to ds with prefix 'raw_' to avoid name conflicts
# an alternative strategy would be to use NetCDF groups
# https://docs.xarray.dev/en/stable/user-guide/io.html#groups
dt.to_netcdf(filename)

# There is no need to save raw data for track data, since it is identical to the matched data
if self.gtype == "point":
ds = self.data.copy() # copy needed to avoid modifying self.data
@staticmethod
def _load(data: xr.DataTree | xr.DataArray) -> "Comparer":
if data.gtype == "track":
return Comparer(matched_data=data["matched"].to_dataset())

for key, ts_mod in self.raw_mod_data.items():
ts_mod = ts_mod.copy()
# rename time to unique name
ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key})
# da = ds_mod.to_xarray()[key]
ds["_raw_" + key] = ts_mod.data[key]
if data.gtype == "point":
raw_mod_data: Dict[str, PointModelResult] = {}

names = [x for x in data["raw"].children]
for var in names:
ds = data["raw"][var].to_dataset()
ts = PointModelResult(data=ds, name=var)

raw_mod_data[var] = ts

ds.to_netcdf(filename)
return Comparer(
matched_data=data["matched"].to_dataset(), raw_mod_data=raw_mod_data
)

else:
raise NotImplementedError(f"Unknown gtype: {data.gtype}")

@staticmethod
def load(filename: Union[str, Path]) -> "Comparer":
Expand All @@ -1251,6 +1281,15 @@ def load(filename: Union[str, Path]) -> "Comparer":
-------
Comparer
"""
try:
with xr.open_datatree(filename) as dt:
data = dt.load()
return Comparer._load(data)
except KeyError:
return Comparer._load_legacy(filename)

@staticmethod
def _load_legacy(filename: str | Path):
with xr.open_dataset(filename) as ds:
data = ds.load()

Expand All @@ -1275,6 +1314,3 @@ def load(filename: Union[str, Path]) -> "Comparer":
data = data[[v for v in data.data_vars if "time" in data[v].dims]]

return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

else:
raise NotImplementedError(f"Unknown gtype: {data.gtype}")
6 changes: 3 additions & 3 deletions src/modelskill/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def vistula() -> ComparerCollection:
-------
ComparerCollection
"""
fn = str(files("modelskill.data") / "vistula.msk")
fn = str(files("modelskill.data") / "vistula.nc")
return ComparerCollection.load(fn)


Expand All @@ -48,5 +48,5 @@ def oresund() -> ComparerCollection:
-------
ComparerCollection
"""
fn = str(files("modelskill.data") / "oresund.msk")
return ComparerCollection.load(fn)
fn = str(files("modelskill.data") / "oresund.nc")
return ms.load(fn)
Binary file removed src/modelskill/data/oresund.msk
Binary file not shown.
Binary file added src/modelskill/data/oresund.nc
Binary file not shown.
Binary file modified src/modelskill/data/vistula.msk
Binary file not shown.
Binary file added src/modelskill/data/vistula.nc
Binary file not shown.
10 changes: 10 additions & 0 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,16 @@ def test_from_matched_dfs0():
) == pytest.approx(0.0476569069177831)


def test_save_and_load(pc: Comparer, tmp_path) -> None:
    """Round-trip a single Comparer through a NetCDF file on disk.

    The reloaded comparer must report the same name and geometry type
    as the comparer that was saved.
    """
    target = tmp_path / "test.nc"
    pc.save(target)

    restored = Comparer.load(target)

    assert restored.name == pc.name
    assert restored.gtype == pc.gtype


def test_from_matched_x_or_x_item_not_both():
with pytest.raises(ValueError, match="x and x_item cannot both be specified"):
ms.from_matched(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def test_save_and_load_preserves_order_of_comparers(tmp_path):
assert cc[1].name == "alpha"
assert cc[2].name == "bravo"

fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
cc.save(fn)

cc2 = modelskill.load(fn)
Expand All @@ -417,7 +417,7 @@ def test_save_and_load_preserves_order_of_comparers(tmp_path):


def test_save(cc: modelskill.ComparerCollection, tmp_path):
fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
assert cc[0].data.attrs["modelskill_version"] == modelskill.__version__
cc.save(fn)

Expand All @@ -429,15 +429,15 @@ def test_save(cc: modelskill.ComparerCollection, tmp_path):


def test_load_from_root_module(cc, tmp_path):
fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
cc.save(fn)

cc2 = modelskill.load(fn)
assert len(cc2) == 2


def test_save_and_load_preserves_raw_model_data(cc, tmp_path):
fn = tmp_path / "test_cc.msk"
fn = tmp_path / "test_cc.nc"
assert len(cc["fake point obs"].raw_mod_data["m1"]) == 6
cc.save(fn)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def test_save_comparercollection(o1, o3, tmp_path):

cc = ms.match([o1, o3], da)

fn = tmp_path / "cc.msk"
fn = tmp_path / "cc.nc"
cc.save(fn)

assert fn.exists()
Expand Down
Loading