Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import geopandas as gpd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import pandas as pd
import scanpy as sc
Expand Down Expand Up @@ -141,6 +142,15 @@ def _render_shapes(
color_source_vector = color_source_vector[mask]
color_vector = color_vector[mask]

# continuous case: leave NaNs as NaNs; utils maps them to na_color during draw
if color_source_vector is None and not values_are_categorical:
color_vector = np.asarray(color_vector, dtype=float)
if np.isnan(color_vector).any():
nan_count = int(np.isnan(color_vector).sum())
logger.warning(
f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'."
)

# Using dict.fromkeys here since set returns in arbitrary order
# remove the color of NaN values, else it might be assigned to a category
# order of color in the palette should agree to order of occurence
Expand Down Expand Up @@ -195,7 +205,10 @@ def _render_shapes(

# Handle circles encoded as points with radius
if is_point.any():
scale = shapes[is_point]["radius"] * render_params.scale
radius_values = shapes[is_point]["radius"]
# Convert to numeric, replacing non-numeric values with NaN
radius_numeric = pd.to_numeric(radius_values, errors="coerce")
scale = radius_numeric * render_params.scale
shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())

# apply transformations to the individual points
Expand All @@ -218,6 +231,20 @@ def _render_shapes(

# in case we are coloring by a column in table
if col_for_color is not None and col_for_color not in transformed_element.columns:
# Ensure color vector length matches the number of shapes
if len(color_vector) != len(transformed_element):
if len(color_vector) == 1:
# If single color, broadcast to all shapes
color_vector = [color_vector[0]] * len(transformed_element)
else:
# If lengths don't match, pad or truncate to match
if len(color_vector) > len(transformed_element):
color_vector = color_vector[: len(transformed_element)]
else:
# Pad with the last color or na_color
na_color = render_params.cmap_params.na_color.get_hex_with_alpha()
color_vector = list(color_vector) + [na_color] * (len(transformed_element) - len(color_vector))

transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector
# Render shapes with datashader
color_by_categorical = col_for_color is not None and color_source_vector is not None
Expand Down Expand Up @@ -447,12 +474,13 @@ def _render_shapes(
path.vertices = trans.transform(path.vertices)

if not values_are_categorical:
# If the user passed a Normalize object with vmin/vmax we'll use those,
# if not we'll use the min/max of the color_vector
_cax.set_clim(
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
)
vmin = render_params.cmap_params.norm.vmin
vmax = render_params.cmap_params.norm.vmax
if vmin is None:
vmin = float(np.nanmin(color_vector))
if vmax is None:
vmax = float(np.nanmax(color_vector))
_cax.set_clim(vmin=vmin, vmax=vmax)

if (
len(set(color_vector)) != 1
Expand Down
Loading