Skip to content
12 changes: 12 additions & 0 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,18 @@ def _(
transformed_dask, raster_translation_single_scale = _transform_raster(
data=xdata.data, axes=xdata.dims, transformation=composed, **kwargs
)

# if a scale in the transformed data has zero shape, we skip it
if 0 in transformed_dask.shape:
if k == "scale0":
raise ValueError(
"The transformation leads to zero shaped data even at the highest resolution level. "
"Check the scaling component of the transformation."
)
# no risk of skipping a scale (e.g. scale1) but not the next ones (e.g. scale2), because once a scale
# is skipped, all the lower scales are also skipped
continue

if raster_translation is None:
raster_translation = raster_translation_single_scale
# we set a dummy empty dict for the transformation that will be replaced with the correct transformation for
Expand Down
33 changes: 32 additions & 1 deletion tests/core/operations/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import unpad_raster
from spatialdata.models import PointsModel, ShapesModel, get_axes_names
from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names
from spatialdata.transformations.operations import (
align_elements_using_landmarks,
get_transformation,
Expand Down Expand Up @@ -229,6 +229,37 @@ def test_transform_shapes(shapes: SpatialData):
assert geom_almost_equals(p0["geometry"], p1["geometry"])


def test_transform_datatree_scale_handling():
"""
Test the cases in which the lowest and highest scale of the result of a
transformed multi-scale image would be zero shape.
"""

test_image = Image2DModel.parse(
np.ones((1, 10, 10)),
dims=("c", "y", "x"),
scale_factors=[2, 4],
transformations={
"cs1": Scale([0.5] * 2, axes=["y", "x"]),
"cs2": Scale([0.01] * 2, axes=["y", "x"]),
},
)

# check that the transform doesn't raise an error and that it
# discards the lowest resolution level
test_image_t = transform(test_image, to_coordinate_system="cs1")
assert list(test_image.keys()) == ["scale0", "scale1", "scale2"]
assert list(test_image_t.keys()) == ["scale0", "scale1"]

# check that a ValueError is raised when no resolution level
# is left after the transformation
with pytest.raises(
ValueError,
match="The transformation leads to zero shaped data even at the highest resolution level",
):
transform(test_image, to_coordinate_system="cs2")


def test_map_coordinate_systems_single_path(full_sdata: SpatialData):
scale = Scale([2], axes=("x",))
translation = Translation([100], axes=("x",))
Expand Down
Loading