Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import SpatialElement, TableModel, get_table_keys
from spatialdata.models import TableModel, get_table_keys
from spatialdata.transformations import (
get_transformation,
remove_transformation,
set_transformation,
)

__all__ = [
"concatenate",
Expand Down Expand Up @@ -78,7 +83,8 @@ def concatenate(
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None,
merge_coordinate_systems_on_name: bool = False,
attrs_merge: (StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None) = None,
**kwargs: Any,
) -> SpatialData:
"""
Expand Down Expand Up @@ -107,6 +113,8 @@ def concatenate(
modify_tables_inplace
Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables
will be copied before modification. Copying is enabled by default but can be disabled for performance reasons.
merge_coordinate_systems_on_name
Whether to keep coordinate system names unchanged (True) or add suffixes (False).
attrs_merge
How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html)
kwargs
Expand Down Expand Up @@ -138,7 +146,10 @@ def concatenate(
rename_tables=not concatenate_tables,
rename_obs_names=obs_names_make_unique and concatenate_tables,
modify_tables_inplace=modify_tables_inplace,
merge_coordinate_systems_on_name=merge_coordinate_systems_on_name,
)
elif merge_coordinate_systems_on_name:
raise ValueError("`merge_coordinate_systems_on_name` can only be used if `sdatas` is a dictionary")

ERROR_STR = (
" must have unique names across the SpatialData objects to concatenate. Please pass a `dict[str, SpatialData]`"
Expand Down Expand Up @@ -219,12 +230,27 @@ def _fix_ensure_unique_element_names(
rename_tables: bool,
rename_obs_names: bool,
modify_tables_inplace: bool,
merge_coordinate_systems_on_name: bool,
) -> list[SpatialData]:
elements_by_sdata: list[dict[str, SpatialElement]] = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this refactoring is fine

tables_by_sdata: list[dict[str, AnnData]] = []
sdatas_fixed = []
for suffix, sdata in sdatas.items():
elements = {f"{name}-{suffix}": el for _, name, el in sdata.gen_spatial_elements()}
elements_by_sdata.append(elements)
# Create new elements dictionary with suffixed names
elements = {}
for _, name, el in sdata.gen_spatial_elements():
new_element_name = f"{name}-{suffix}"
if not merge_coordinate_systems_on_name:
# Set new transformations with suffixed coordinate system names
transformations = get_transformation(el, get_all=True)
assert isinstance(transformations, dict)

remove_transformation(el, remove_all=True)
for cs, t in transformations.items():
new_cs = f"{cs}-{suffix}"
set_transformation(el, t, to_coordinate_system=new_cs)

elements[new_element_name] = el

# Handle tables with suffix
tables = {}
for name, table in sdata.tables.items():
if not modify_tables_inplace:
Expand All @@ -248,11 +274,8 @@ def _fix_ensure_unique_element_names(
# fix the table name
new_name = f"{name}-{suffix}" if rename_tables else name
tables[new_name] = table
tables_by_sdata.append(tables)
sdatas_fixed = []
for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True):
if tables is not None:
elements.update(tables)
sdata = SpatialData.init_from_elements(elements)
sdatas_fixed.append(sdata)

# Create new SpatialData object with suffixed elements and tables
sdata_fixed = SpatialData.init_from_elements(elements | tables)
sdatas_fixed.append(sdata_fixed)
return sdatas_fixed
8 changes: 6 additions & 2 deletions src/spatialdata/transformations/ngff/ngff_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ def transform_points(self, points: ArrayLike) -> ArrayLike:
self._validate_transform_points_shapes(len(input_axes), points.shape)
p = np.vstack([points.T, np.ones(points.shape[0])])
q = self.affine @ p
return q[: len(output_axes), :].T # type: ignore[no-any-return]
res = q[: len(output_axes), :].T
assert isinstance(res, np.ndarray)
return res

def to_affine(self) -> "NgffAffine":
return NgffAffine(
Expand Down Expand Up @@ -743,7 +745,9 @@ def _get_and_validate_axes(self) -> tuple[tuple[str, ...], tuple[str, ...]]:
def transform_points(self, points: ArrayLike) -> ArrayLike:
input_axes, _ = self._get_and_validate_axes()
self._validate_transform_points_shapes(len(input_axes), points.shape)
return (self.rotation @ points.T).T # type: ignore[no-any-return]
res = (self.rotation @ points.T).T
assert isinstance(res, np.ndarray)
return res

def to_affine(self) -> NgffAffine:
m = np.eye(len(self.rotation) + 1)
Expand Down
41 changes: 41 additions & 0 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,47 @@ def _n_elements(sdata: SpatialData) -> int:
assert "blobs_image-sample" in c.images


@pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False])
def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name):
blob1 = blobs()
blob2 = blobs()

if merge_coordinate_systems_on_name:
with pytest.raises(
ValueError,
match="`merge_coordinate_systems_on_name` can only be used if `sdatas` is a dictionary",
):
concatenate((blob1, blob2), merge_coordinate_systems_on_name=merge_coordinate_systems_on_name)

sdata_keys = ["blob1", "blob2"]
sdata = concatenate(
dict(zip(sdata_keys, [blob1, blob2], strict=True)),
merge_coordinate_systems_on_name=merge_coordinate_systems_on_name,
)

if merge_coordinate_systems_on_name:
assert set(sdata.coordinate_systems) == {"global"}
else:
assert set(sdata.coordinate_systems) == {"global-blob1", "global-blob2"}

# extra checks not specific to this test, we could remove them or leave them just
# in case
expected_images = ["blobs_image", "blobs_multiscale_image"]
expected_labels = ["blobs_labels", "blobs_multiscale_labels"]
expected_points = ["blobs_points"]
expected_shapes = ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]

expected_suffixed_images = [f"{name}-{key}" for key in sdata_keys for name in expected_images]
expected_suffixed_labels = [f"{name}-{key}" for key in sdata_keys for name in expected_labels]
expected_suffixed_points = [f"{name}-{key}" for key in sdata_keys for name in expected_points]
expected_suffixed_shapes = [f"{name}-{key}" for key in sdata_keys for name in expected_shapes]

assert set(sdata.images.keys()) == set(expected_suffixed_images)
assert set(sdata.labels.keys()) == set(expected_suffixed_labels)
assert set(sdata.points.keys()) == set(expected_suffixed_points)
assert set(sdata.shapes.keys()) == set(expected_suffixed_shapes)


def test_locate_spatial_element(full_sdata: SpatialData) -> None:
assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d"
im = full_sdata.images["image2d"]
Expand Down
Loading