From faa02360dcd51ae997fcf88c32d84f4362cec75a Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 5 Nov 2024 03:00:11 +0100 Subject: [PATCH 1/6] import DataTree from xarray --- src/spatialdata_plot/pl/basic.py | 3 +-- src/spatialdata_plot/pl/render.py | 2 +- src/spatialdata_plot/pl/utils.py | 3 +-- tests/conftest.py | 3 +-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index f3cc0fc5..cb36270a 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -14,14 +14,13 @@ import spatialdata as sd from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from matplotlib.axes import Axes from matplotlib.colors import Colormap, Normalize from matplotlib.figure import Figure from spatialdata import get_extent from spatialdata._utils import _deprecation_alias -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata_plot._accessor import register_spatial_data_accessor from spatialdata_plot.pl.render import ( diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 24df6f9b..d4b7cfee 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -15,7 +15,6 @@ import scanpy as sc import spatialdata as sd from anndata import AnnData -from datatree import DataTree from matplotlib.cm import ScalarMappable from matplotlib.colors import ListedColormap, Normalize from scanpy._settings import settings as sc_settings @@ -24,6 +23,7 @@ from spatialdata.transformations import ( set_transformation, ) +from xarray import DataTree from spatialdata_plot._logging import logger from spatialdata_plot.pl.render_params import ( diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 77e73273..0a04cc7e 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -26,7 +26,6 @@ from anndata import AnnData from cycler import Cycler, cycler from datashader.core import Canvas -from datatree import DataTree from geopandas import GeoDataFrame from matplotlib import colors, patheffects, rcParams from matplotlib.axes import Axes @@ -61,7 +60,7 @@ from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, get_model from spatialdata.transformations.operations import get_transformation from spatialdata.transformations.transformations import Scale -from xarray import DataArray +from xarray import DataArray, DataTree from spatialdata_plot._logging import logger from spatialdata_plot.pl.render_params import ( diff --git a/tests/conftest.py b/tests/conftest.py index 884cb07c..05ef1b97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,6 @@ import pytest import spatialdata as sd from anndata import AnnData -from datatree import DataTree from geopandas import GeoDataFrame from matplotlib.testing.compare import compare_images from shapely.geometry import MultiPolygon, Polygon @@ -25,7 +24,7 @@ ShapesModel, TableModel, ) -from xarray import DataArray +from xarray import DataArray, DataTree import spatialdata_plot # noqa: F401 From 05fac4b3aef3342323340acf83f4c4057ed9e644 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 11 Nov 2024 20:23:17 +0100 Subject: [PATCH 2/6] update mypy to 3.10 --- .github/workflows/test.yaml | 2 +- .mypy.ini | 4 ++-- pyproject.toml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a5c46660..c690be1b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python: ["3.9", "3.10"] + python: ["3.10", "3.12"] os: [ubuntu-latest] env: diff --git a/.mypy.ini b/.mypy.ini index 2cd7b3d6..77bf7465 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.9 +python_version = 3.10 plugins = numpy.typing.mypy_plugin ignore_errors = False @@ -25,4 +25,4 @@ no_warn_no_return = True show_error_codes = True show_column_numbers = True -error_summary = True \ No newline at end of file +error_summary = True diff --git a/pyproject.toml b/pyproject.toml index 91becefe..52014620 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ filterwarnings = [ [tool.black] line-length = 120 -target-version = ['py39'] +target-version = ['py310'] include = '\.pyi?$' exclude = ''' ( @@ -158,7 +158,7 @@ lint.select = [ "PGH", # pygrep-hooks ] lint.unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"] -target-version = "py39" +target-version = "py310" [tool.ruff.lint.per-file-ignores] "tests/*" = ["D", "PT", "B024"] "*/__init__.py" = ["F401", "D104", "D107", "E402"] From 7558de34b3bb9af9c45b9e0828a1359520f68230 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 11 Nov 2024 23:59:40 +0100 Subject: [PATCH 3/6] fix pre-commit --- src/spatialdata_plot/pl/basic.py | 4 +- src/spatialdata_plot/pl/render.py | 3 +- src/spatialdata_plot/pl/render_params.py | 4 +- src/spatialdata_plot/pl/utils.py | 63 +++++++++++++----------- tests/conftest.py | 16 +++--- tests/pl/test_utils.py | 4 +- 6 files changed, 47 insertions(+), 47 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index cb36270a..d23b9710 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -5,7 +5,7 @@ from collections import OrderedDict from copy import deepcopy from pathlib import Path -from typing import Any, Union +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -61,7 +61,7 @@ # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = Union[tuple[float, ...], str] +ColorLike = tuple[float, ...] | str @register_spatial_data_accessor("pl") diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index d4b7cfee..a8a07058 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -3,7 +3,6 @@ import warnings from collections import abc from copy import copy -from typing import Union import dask import datashader as ds @@ -56,7 +55,7 @@ to_hex, ) -_Normalize = Union[Normalize, abc.Sequence[Normalize]] +_Normalize = Normalize | abc.Sequence[Normalize] def _render_shapes( diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 88f3cf7b..ee37c57f 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import Literal, Union +from typing import Literal from matplotlib.axes import Axes from matplotlib.colors import Colormap, ListedColormap, Normalize @@ -14,7 +14,7 @@ # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = Union[tuple[float, ...], str] +ColorLike = tuple[float, ...] | str @dataclass diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 0a04cc7e..dd3d5ec0 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -8,7 +8,7 @@ from functools import partial from pathlib import Path from types import MappingProxyType -from typing import Any, Literal, Union +from typing import Any, Literal import dask import datashader as ds @@ -81,7 +81,7 @@ # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = Union[tuple[float, ...], str] +ColorLike = tuple[float, ...] | str def _verify_plotting_tree(sdata: SpatialData) -> SpatialData: @@ -526,7 +526,7 @@ def _set_outline( outline_color: str | list[float] = "#0000000ff", # black, white **kwargs: Any, ) -> OutlineParams: - if not isinstance(outline_width, (int, float)): + if not isinstance(outline_width, int | float): raise TypeError(f"Invalid type of `outline_width`: {type(outline_width)}, expected `int` or `float`.") if outline_width == 0.0: outline = False @@ -868,9 +868,9 @@ def _generate_base_categorial_color_mapping( na_color = to_hex(to_rgba(na_color)[:3]) if na_color and len(categories) > len(colors): - return dict(zip(categories, colors + [na_color])) + return dict(zip(categories, colors + [na_color], strict=True)) - return dict(zip(categories, colors)) + return dict(zip(categories, colors, strict=True)) return _get_default_categorial_color_mapping(color_source_vector) @@ -887,7 +887,7 @@ def _modify_categorical_color_mapping( # subset base mapping to only those specified in groups modified_mapping = {key: mapping[key] for key in mapping if key in groups or key == "NaN"} elif len(palette) == len(groups) and isinstance(groups, list) and isinstance(palette, list): - modified_mapping = dict(zip(groups, palette)) + modified_mapping = dict(zip(groups, palette, strict=True)) else: raise ValueError(f"Expected palette to be of length `{len(groups)}`, found `{len(palette)}`.") @@ -908,7 +908,10 @@ def _get_default_categorial_color_mapping( palette = ["grey" for _ in range(len_cat)] logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") - return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(color_source_vector.categories, palette[:len_cat])} + return { + cat: to_hex(to_rgba(col)[:3]) + for cat, col in zip(color_source_vector.categories, palette[:len_cat], strict=True) + } def _get_categorical_color_mapping( @@ -1342,7 +1345,7 @@ def _multiscale_to_spatial_image( optimal_index_x -= 1 # pick the scale with higher resolution (worst case: downscaled afterwards) - optimal_scale = scales[min(optimal_index_x, optimal_index_y)] + optimal_scale = scales[min(int(optimal_index_x), int(optimal_index_y))] # NOTE: problematic if there are cases with > 1 data variable data_var_keys = list(multiscale_image[optimal_scale].data_vars) @@ -1412,12 +1415,12 @@ def _validate_show_parameters( return_ax: bool, save: str | Path | None, ) -> None: - if coordinate_systems is not None and not isinstance(coordinate_systems, (list, str)): + if coordinate_systems is not None and not isinstance(coordinate_systems, list | str): raise TypeError("Parameter 'coordinate_systems' must be a string or a list of strings.") font_weights = ["light", "normal", "medium", "semibold", "bold", "heavy", "black"] if legend_fontweight is not None and ( - not isinstance(legend_fontweight, (int, str)) + not isinstance(legend_fontweight, int | str) or (isinstance(legend_fontweight, str) and legend_fontweight not in font_weights) ): readable_font_weights = ", ".join(font_weights[:-1]) + ", or " + font_weights[-1] @@ -1429,7 +1432,7 @@ def _validate_show_parameters( font_sizes = ["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"] if legend_fontsize is not None and ( - not isinstance(legend_fontsize, (int, float, str)) + not isinstance(legend_fontsize, int | float | str) or (isinstance(legend_fontsize, str) and legend_fontsize not in font_sizes) ): readable_font_sizes = ", ".join(font_sizes[:-1]) + ", or " + font_sizes[-1] @@ -1471,22 +1474,22 @@ def _validate_show_parameters( if fig is not None and not isinstance(fig, Figure): raise TypeError("Parameter 'fig' must be a matplotlib.figure.Figure.") - if title is not None and not isinstance(title, (list, str)): + if title is not None and not isinstance(title, list | str): raise TypeError("Parameter 'title' must be a string or a list of strings.") if not isinstance(share_extent, bool): raise TypeError("Parameter 'share_extent' must be a boolean.") - if not isinstance(pad_extent, (int, float)): + if not isinstance(pad_extent, int | float): raise TypeError("Parameter 'pad_extent' must be numeric.") - if ax is not None and not isinstance(ax, (Axes, list)): + if ax is not None and not isinstance(ax, Axes | list): raise TypeError("Parameter 'ax' must be a matplotlib.axes.Axes or a list of Axes.") if not isinstance(return_ax, bool): raise TypeError("Parameter 'return_ax' must be a boolean.") - if save is not None and not isinstance(save, (str, Path)): + if save is not None and not isinstance(save, str | Path): raise TypeError("Parameter 'save' must be a string or a pathlib.Path.") @@ -1505,10 +1508,10 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif element_type == "shapes": param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys()) - if (channel := param_dict.get("channel")) is not None and not isinstance(channel, (list, str, int)): + if (channel := param_dict.get("channel")) is not None and not isinstance(channel, list | str | int): raise TypeError("Parameter 'channel' must be a string, an integer, or a list of strings or integers.") if isinstance(channel, list): - if not all(isinstance(c, (str, int)) for c in channel): + if not all(isinstance(c, str | int) for c in channel): raise TypeError("Each item in 'channel' list must be a string or an integer.") if not all(isinstance(c, type(channel[0])) for c in channel): raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.") @@ -1533,13 +1536,13 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st param_dict["col_for_color"] = None if outline_width := param_dict.get("outline_width"): - if not isinstance(outline_width, (float, int)): + if not isinstance(outline_width, float | int): raise TypeError("Parameter 'outline_width' must be numeric.") if outline_width < 0: raise ValueError("Parameter 'outline_width' cannot be negative.") if (outline_alpha := param_dict.get("outline_alpha")) and ( - not isinstance(outline_alpha, (float, int)) or not 0 <= outline_alpha <= 1 + not isinstance(outline_alpha, float | int) or not 0 <= outline_alpha <= 1 ): raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") @@ -1547,13 +1550,13 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st raise ValueError("Parameter 'contour_px' must be a positive number.") if (alpha := param_dict.get("alpha")) is not None: - if not isinstance(alpha, (float, int)): + if not isinstance(alpha, float | int): raise TypeError("Parameter 'alpha' must be numeric.") if not 0 <= alpha <= 1: raise ValueError("Parameter 'alpha' must be between 0 and 1.") if (fill_alpha := param_dict.get("fill_alpha")) is not None: - if not isinstance(fill_alpha, (float, int)): + if not isinstance(fill_alpha, float | int): raise TypeError("Parameter 'fill_alpha' must be numeric.") if fill_alpha < 0: raise ValueError("Parameter 'fill_alpha' cannot be negative.") @@ -1563,7 +1566,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st param_dict["cmap"] = cmap if (groups := param_dict.get("groups")) is not None: - if not isinstance(groups, (list, str)): + if not isinstance(groups, list | str): raise TypeError("Parameter 'groups' must be a string or a list of strings.") if isinstance(groups, str): param_dict["groups"] = [groups] @@ -1575,7 +1578,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if isinstance((palette := param_dict["palette"]), list): if not all(isinstance(p, str) for p in palette): raise ValueError("If specified, parameter 'palette' must contain only strings.") - elif isinstance(palette, (str, type(None))) and "palette" in param_dict: + elif isinstance(palette, str | type(None)) and "palette" in param_dict: param_dict["palette"] = [palette] if palette is not None else None if element_type in ["shapes", "points", "labels"] and (palette := param_dict.get("palette")) is not None: @@ -1589,9 +1592,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st ) if isinstance(cmap, list): - if not all(isinstance(c, (Colormap, str)) for c in cmap): + if not all(isinstance(c, Colormap | str) for c in cmap): raise TypeError("Each item in 'cmap' list must be a string or a Colormap.") - elif isinstance(cmap, (Colormap, str, type(None))): + elif isinstance(cmap, Colormap | str | type(None)): if "cmap" in param_dict: param_dict["cmap"] = [cmap] if cmap is not None else None else: @@ -1605,20 +1608,20 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if (norm := param_dict.get("norm")) is not None: if element_type in ["images", "labels"] and not isinstance(norm, Normalize): raise TypeError("Parameter 'norm' must be of type Normalize.") - if element_type in ["shapes", "points"] and not isinstance(norm, (bool, Normalize)): + if element_type in ["shapes", "points"] and not isinstance(norm, bool | Normalize): raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") if (scale := param_dict.get("scale")) is not None: if element_type in ["images", "labels"] and not isinstance(scale, str): raise TypeError("Parameter 'scale' must be a string if specified.") if element_type == "shapes": - if not isinstance(scale, (float, int)): + if not isinstance(scale, float | int): raise TypeError("Parameter 'scale' must be numeric.") if scale < 0: raise ValueError("Parameter 'scale' must be a positive number.") if size := param_dict.get("size"): - if not isinstance(size, (float, int)): + if not isinstance(size, float | int): raise TypeError("Parameter 'size' must be numeric.") if size < 0: raise ValueError("Parameter 'size' must be a positive number.") @@ -1968,7 +1971,7 @@ def _is_coercable_to_float(series: pd.Series) -> bool: def _ax_show_and_transform( - array: MaskedArray[np.float64, Any], + array: MaskedArray[tuple[int, ...], Any], trans_data: CompositeGenericTransform, ax: Axes, alpha: float | None = None, @@ -2052,7 +2055,7 @@ def _get_extent_and_range_for_datashader_canvas( def _create_image_from_datashader_result( ds_result: ds.transfer_functions.Image, factor: float, ax: Axes -) -> tuple[MaskedArray[np.float64, Any], matplotlib.transforms.CompositeGenericTransform]: +) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.CompositeGenericTransform]: # create SpatialImage from datashader output to get it back to original size rgba_image_data = ds_result.to_numpy().base rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1)) diff --git a/tests/conftest.py b/tests/conftest.py index 05ef1b97..27adff68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ from abc import ABC, ABCMeta +from collections.abc import Callable from functools import wraps from pathlib import Path -from typing import Callable, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -216,7 +216,7 @@ def sdata(request) -> SpatialData: return s -def _get_images() -> dict[str, Union[DataArray, DataTree]]: +def _get_images() -> dict[str, DataArray | DataTree]: out = {} dims_2d = ("c", "y", "x") dims_3d = ("z", "y", "x", "c") @@ -243,7 +243,7 @@ def _get_images() -> dict[str, Union[DataArray, DataTree]]: return out -def _get_labels() -> dict[str, Union[DataArray, DataTree]]: +def _get_labels() -> dict[str, DataArray | DataTree]: out = {} dims_2d = ("y", "x") dims_3d = ("z", "y", "x") @@ -344,9 +344,9 @@ def _get_points() -> dict[str, pa.Table]: def _get_table( - region: Optional[AnnData] = None, - region_key: Optional[str] = None, - instance_key: Optional[str] = None, + region: AnnData | None = None, + region_key: str | None = None, + instance_key: str | None = None, ) -> AnnData: region_key = region_key or "annotated_region" instance_key = instance_key or "instance_id" @@ -374,7 +374,7 @@ def __new__(cls, clsname, superclasses, attributedict): class PlotTester(ABC): # noqa: B024 @classmethod - def compare(cls, basename: str, tolerance: Optional[float] = None): + def compare(cls, basename: str, tolerance: float | None = None): ACTUAL.mkdir(parents=True, exist_ok=True) out_path = ACTUAL / f"{basename}.png" @@ -397,7 +397,7 @@ def compare(cls, basename: str, tolerance: Optional[float] = None): assert res is None, res -def _decorate(fn: Callable, clsname: str, name: Optional[str] = None) -> Callable: +def _decorate(fn: Callable, clsname: str, name: str | None = None) -> Callable: @wraps(fn) def save_and_compare(self, *args, **kwargs): fn(self, *args, **kwargs) diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index 2c017055..0ad75710 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -1,5 +1,3 @@ -from typing import Union - import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -28,7 +26,7 @@ # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = Union[tuple[float, ...], str] +ColorLike = tuple[float, ...] | str class TestUtils(PlotTester, metaclass=PlotTesterMeta): From 6854121209a40cde7044ab4715f4802c2e1d0406 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 12 Nov 2024 00:01:14 +0100 Subject: [PATCH 4/6] change required python --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 52014620..1ae44391 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ maintainers = [ urls.Documentation = "https://spatialdata.scverse.org/projects/plot/en/latest/index.html" urls.Source = "https://github.com/scverse/spatialdata-plot.git" urls.Home-page = "https://github.com/scverse/spatialdata-plot.git" -requires-python = ">=3.9" +requires-python = ">=3.10" dynamic= [ "version" # allow version to be set by git tags ] From 9e4274c959179053849bf3f3d0a003e6034aaed0 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 26 Nov 2024 14:55:48 +0100 Subject: [PATCH 5/6] adjust for scanpy returning df --- src/spatialdata_plot/pl/basic.py | 2 +- src/spatialdata_plot/pl/render.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index d23b9710..78c59436 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -949,7 +949,7 @@ def show( if wanted_labels_on_this_cs: if (table := params_copy.table_name) is not None: colors = sc.get.obs_df(sdata[table], params_copy.color) - if isinstance(colors.dtype, pd.CategoricalDtype): + if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype): _maybe_set_colors( source=sdata[table], target=sdata[table], diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index a8a07058..a3a6c1a1 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -441,7 +441,7 @@ def _render_points( if col_for_color is not None: cols = sc.get.obs_df(adata, col_for_color) # maybe set color based on type - if isinstance(cols.dtype, pd.CategoricalDtype): + if isinstance(cols[col_for_color].dtype, pd.CategoricalDtype): _maybe_set_colors( source=adata, target=adata, From f069e788c4001c91a97c821e714bed3d9176418f Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Tue, 26 Nov 2024 15:21:22 +0100 Subject: [PATCH 6/6] workaround pre-release not installing in test ci --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c690be1b..e0b23ba7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -49,7 +49,7 @@ jobs: pip install pytest-cov - name: Install dependencies run: | - pip install --pre -e ".[dev,test]" + pip install --pre -e ".[dev,test,pre]" - name: Test env: MPLBACKEND: agg