From 769c1c55c9df3cc6dff66f0759d10b5c7896c970 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 13 Jan 2026 13:31:31 +1000 Subject: [PATCH 1/4] Fix SubplotGrid indexing and allow legend placement decoupling --- ultraplot/figure.py | 12 ++++++++++++ ultraplot/gridspec.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 5d302f318..53b335bcb 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -1417,6 +1417,18 @@ def _add_axes_panel( if span_override is not None: kw["span_override"] = span_override + # Check for position override (row for horizontal panels, col for vertical panels) + pos_override = None + if side in ("left", "right"): + if _not_none(cols, col) is not None: + pos_override = _not_none(cols, col) + else: + if _not_none(rows, row) is not None: + pos_override = _not_none(rows, row) + + if pos_override is not None: + kw["pos_override"] = pos_override + ss, share = gs._insert_panel_slot(side, ax, **kw) # Guard: GeoAxes with non-rectilinear projections cannot share with panels if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear(): diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 288f1abc4..bae0c389c 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -425,6 +425,12 @@ def _encode_indices(self, *args, which=None, panel=False): nums = [] idxs = self._get_indices(which=which, panel=panel) for arg in args: + if isinstance(arg, (list, np.ndarray)): + try: + nums.append([idxs[int(i)] for i in arg]) + except (IndexError, TypeError): + raise ValueError(f"Invalid gridspec index {arg}.") + continue try: nums.append(idxs[arg]) except (IndexError, TypeError): @@ -595,6 +601,7 @@ def _parse_panel_arg_with_span( side: str, ax: "paxes.Axes", span_override: Optional[Union[int, Tuple[int, int]]], + pos_override: Optional[Union[int, Tuple[int, int]]] = None, ) -> Tuple[str, int, slice]: """ Parse panel arg with span override. Uses ax for position, span for extent. @@ -607,6 +614,8 @@ def _parse_panel_arg_with_span( The axes to position the panel relative to span_override : int or tuple The span extent (1-indexed like subplot numbers) + pos_override : int or tuple, optional + The row or column index (1-indexed like subplot numbers) Returns ------- @@ -621,6 +630,20 @@ def _parse_panel_arg_with_span( ss = ax.get_subplotspec().get_topmost_subplotspec() row1, row2, col1, col2 = ss._get_rows_columns() + # Override axes position if requested + if pos_override is not None: + if isinstance(pos_override, Integral): + pos1, pos2 = pos_override - 1, pos_override - 1 + else: + pos_override = np.atleast_1d(pos_override) + pos1, pos2 = pos_override[0] - 1, pos_override[-1] - 1 + + # NOTE: We only need the relevant coordinate (row or col) + if side in ("left", "right"): + col1, col2 = pos1, pos2 + else: + row1, row2 = pos1, pos2 + # Determine slot and index based on side slot = side[0] offset = len(ax._panel_dict[side]) + 1 @@ -663,6 +686,7 @@ def _insert_panel_slot( pad: Optional[Union[float, str]] = None, filled: bool = False, span_override: Optional[Union[int, Tuple[int, int]]] = None, + pos_override: Optional[Union[int, Tuple[int, int]]] = None, ): """ Insert a panel slot into the existing gridspec. The `side` is the panel side @@ -676,7 +700,9 @@ def _insert_panel_slot( raise ValueError(f"Invalid side {side}.") # Use span override if provided if span_override is not None: - slot, idx, span = self._parse_panel_arg_with_span(side, arg, span_override) + slot, idx, span = self._parse_panel_arg_with_span( + side, arg, span_override, pos_override=pos_override + ) else: slot, idx, span = self._parse_panel_arg(side, arg) pad = units(pad, "em", "in") @@ -1612,10 +1638,13 @@ def __getitem__(self, key): >>> axs[:, 0] # a SubplotGrid containing the subplots in the first column """ # Allow 1D list-like indexing - if isinstance(key, int): + if isinstance(key, (Integral, np.integer)): return list.__getitem__(self, key) elif isinstance(key, slice): return SubplotGrid(list.__getitem__(self, key)) + elif isinstance(key, (list, np.ndarray)): + # NOTE: list.__getitem__ does not support numpy integers + return SubplotGrid([list.__getitem__(self, int(i)) for i in key]) # Allow 2D array-like indexing # NOTE: We assume this is a 2D array of subplots, because this is From ab996e1fd70fd1892ec138803fb8456e9e4dae8b Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 13 Jan 2026 13:42:35 +1000 Subject: [PATCH 2/4] Add ref argument to fig.legend, support 1D slicing, and intelligent placement inference --- docs/colorbars_legends.py | 41 +++++++++ ultraplot/figure.py | 138 +++++++++++++++++++++++++------ ultraplot/gridspec.py | 22 +---- ultraplot/tests/test_gridspec.py | 56 ++++++++++++- ultraplot/tests/test_legend.py | 40 +++++++++ 5 files changed, 249 insertions(+), 48 deletions(-) diff --git a/docs/colorbars_legends.py b/docs/colorbars_legends.py index 10a4099c8..8e8002975 100644 --- a/docs/colorbars_legends.py +++ b/docs/colorbars_legends.py @@ -469,3 +469,44 @@ ax = axs[1] ax.legend(hs2, loc="b", ncols=3, center=True, title="centered rows") axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Legend formatting demo") +# %% [raw] raw_mimetype="text/restructuredtext" +# .. _ug_guides_decouple: +# +# Decoupling legend content and location +# -------------------------------------- +# +# Sometimes you may want to generate a legend using handles from specific axes +# but place it relative to other axes. In UltraPlot, you can achieve this by passing +# both the `ax` and `ref` keywords to :func:`~ultraplot.figure.Figure.legend` +# (or :func:`~ultraplot.figure.Figure.colorbar`). The `ax` keyword specifies the +# axes used to generate the legend handles, while the `ref` keyword specifies the +# reference axes used to determine the legend location. +# +# For example, to draw a legend based on the handles in the second row of subplots +# but place it below the first row of subplots, you can use +# ``fig.legend(ax=axs[1, :], ref=axs[0, :], loc='bottom')``. If ``ref`` is a list +# of axes, UltraPlot intelligently infers the span (width or height) and anchors +# the legend to the appropriate outer edge (e.g., the bottom-most axis for ``loc='bottom'`` +# or the right-most axis for ``loc='right'``). + +# %% +import numpy as np + +import ultraplot as uplt + +fig, axs = uplt.subplots(nrows=2, ncols=2, refwidth=2, share=False) +axs.format(abc="A.", suptitle="Decoupled legend location demo") + +# Plot data on all axes +state = np.random.RandomState(51423) +data = (state.rand(20, 4) - 0.5).cumsum(axis=0) +for ax in axs: + ax.plot(data, cycle="mplotcolors", labels=list("abcd")) + +# Legend 1: Content from Row 2 (ax=axs[1, :]), Location below Row 1 (ref=axs[0, :]) +# This places a legend describing the bottom row data underneath the top row. +fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom", title="Data from Row 2") + +# Legend 2: Content from Row 1 (ax=axs[0, :]), Location below Row 2 (ref=axs[1, :]) +# This places a legend describing the top row data underneath the bottom row. +fig.legend(ax=axs[0, :], ref=axs[1, :], loc="bottom", title="Data from Row 1") diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 53b335bcb..53d955835 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -1417,18 +1417,6 @@ def _add_axes_panel( if span_override is not None: kw["span_override"] = span_override - # Check for position override (row for horizontal panels, col for vertical panels) - pos_override = None - if side in ("left", "right"): - if _not_none(cols, col) is not None: - pos_override = _not_none(cols, col) - else: - if _not_none(rows, row) is not None: - pos_override = _not_none(rows, row) - - if pos_override is not None: - kw["pos_override"] = pos_override - ss, share = gs._insert_panel_slot(side, ax, **kw) # Guard: GeoAxes with non-rectilinear projections cannot share with panels if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear(): @@ -2712,27 +2700,125 @@ def legend( matplotlib.axes.Axes.legend """ ax = kwargs.pop("ax", None) + ref = kwargs.pop("ref", None) + loc_ax = ref if ref is not None else ax + # Axes panel legend - if ax is not None: + if loc_ax is not None: + content_ax = ax if ax is not None else loc_ax # Check if span parameters are provided has_span = _not_none(span, row, col, rows, cols) is not None - # Extract a single axes from array if span is provided - # Otherwise, pass the array as-is for normal legend behavior - # Automatically collect handles and labels from spanned axes if not provided - if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): - # Auto-collect handles and labels if not explicitly provided - if handles is None and labels is None: - handles, labels = [], [] - for axi in ax: + + # Automatically collect handles and labels from content axes if not provided + # Case 1: content_ax is a list (we must auto-collect) + # Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles) + must_collect = ( + np.iterable(content_ax) + and not isinstance(content_ax, (str, maxes.Axes)) + ) or (content_ax is not loc_ax) + + if must_collect and handles is None and labels is None: + handles, labels = [], [] + # Handle list of axes + if np.iterable(content_ax) and not isinstance( + content_ax, (str, maxes.Axes) + ): + for axi in content_ax: h, l = axi.get_legend_handles_labels() handles.extend(h) labels.extend(l) - try: - ax_single = next(iter(ax)) - except (TypeError, StopIteration): - ax_single = ax + # Handle single axis + else: + handles, labels = content_ax.get_legend_handles_labels() + + # Infer span from loc_ax if it is a list and no span provided + if ( + not has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + if side: + r_min, r_max = float("inf"), float("-inf") + c_min, c_max = float("inf"), float("-inf") + valid_ax = False + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec().get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + r_min = min(r_min, r1) + r_max = max(r_max, r2) + c_min = min(c_min, c1) + c_max = max(c_max, c2) + valid_ax = True + + if valid_ax: + if side in ("left", "right"): + rows = (r_min + 1, r_max + 1) + else: + cols = (c_min + 1, c_max + 1) + has_span = True + + # Extract a single axes from array if span is provided (or if ref is a list) + # Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list) + if ( + has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + # Pick the best axis to anchor to based on the legend side + loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + best_ax = None + best_coord = float("-inf") + + # If side is determined, search for the edge axis + if side: + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec().get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + + if side == "right": + val = c2 # Maximize column index + elif side == "left": + val = -c1 # Minimize column index + elif side == "bottom": + val = r2 # Maximize row index + elif side == "top": + val = -r1 # Minimize row index + else: + val = 0 + + if val > best_coord: + best_coord = val + best_ax = axi + + # Fallback to first axis if no best axis found (or side is None) + if best_ax is None: + try: + ax_single = next(iter(loc_ax)) + except (TypeError, StopIteration): + ax_single = loc_ax + else: + ax_single = best_ax + else: - ax_single = ax + ax_single = loc_ax + leg = ax_single.legend( handles, labels, diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index bae0c389c..93a6343a5 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -601,7 +601,6 @@ def _parse_panel_arg_with_span( side: str, ax: "paxes.Axes", span_override: Optional[Union[int, Tuple[int, int]]], - pos_override: Optional[Union[int, Tuple[int, int]]] = None, ) -> Tuple[str, int, slice]: """ Parse panel arg with span override. Uses ax for position, span for extent. @@ -614,8 +613,6 @@ def _parse_panel_arg_with_span( The axes to position the panel relative to span_override : int or tuple The span extent (1-indexed like subplot numbers) - pos_override : int or tuple, optional - The row or column index (1-indexed like subplot numbers) Returns ------- @@ -630,20 +627,6 @@ def _parse_panel_arg_with_span( ss = ax.get_subplotspec().get_topmost_subplotspec() row1, row2, col1, col2 = ss._get_rows_columns() - # Override axes position if requested - if pos_override is not None: - if isinstance(pos_override, Integral): - pos1, pos2 = pos_override - 1, pos_override - 1 - else: - pos_override = np.atleast_1d(pos_override) - pos1, pos2 = pos_override[0] - 1, pos_override[-1] - 1 - - # NOTE: We only need the relevant coordinate (row or col) - if side in ("left", "right"): - col1, col2 = pos1, pos2 - else: - row1, row2 = pos1, pos2 - # Determine slot and index based on side slot = side[0] offset = len(ax._panel_dict[side]) + 1 @@ -686,7 +669,6 @@ def _insert_panel_slot( pad: Optional[Union[float, str]] = None, filled: bool = False, span_override: Optional[Union[int, Tuple[int, int]]] = None, - pos_override: Optional[Union[int, Tuple[int, int]]] = None, ): """ Insert a panel slot into the existing gridspec. The `side` is the panel side @@ -700,9 +682,7 @@ def _insert_panel_slot( raise ValueError(f"Invalid side {side}.") # Use span override if provided if span_override is not None: - slot, idx, span = self._parse_panel_arg_with_span( - side, arg, span_override, pos_override=pos_override - ) + slot, idx, span = self._parse_panel_arg_with_span(side, arg, span_override) else: slot, idx, span = self._parse_panel_arg(side, arg) pad = units(pad, "em", "in") diff --git a/ultraplot/tests/test_gridspec.py b/ultraplot/tests/test_gridspec.py index e3890d7a3..b676f36a9 100644 --- a/ultraplot/tests/test_gridspec.py +++ b/ultraplot/tests/test_gridspec.py @@ -1,5 +1,6 @@ -import ultraplot as uplt import pytest + +import ultraplot as uplt from ultraplot.gridspec import SubplotGrid @@ -72,3 +73,56 @@ def test_tight_layout_disabled(): gs = ax.get_subplotspec().get_gridspec() with pytest.raises(RuntimeError): gs.tight_layout(fig) + + +def test_gridspec_slicing(): + """ + Test various slicing methods on SubplotGrid, including 1D list/array indexing. + """ + import numpy as np + + fig, axs = uplt.subplots(nrows=4, ncols=4) + + # Test 1D integer indexing + assert axs[0].number == 1 + assert axs[15].number == 16 + + # Test 1D slice indexing + subset = axs[0:2] + assert isinstance(subset, SubplotGrid) + assert len(subset) == 2 + assert subset[0].number == 1 + assert subset[1].number == 2 + + # Test 1D list indexing (Fix #1) + subset_list = axs[[0, 5]] + assert isinstance(subset_list, SubplotGrid) + assert len(subset_list) == 2 + assert subset_list[0].number == 1 + assert subset_list[1].number == 6 + + # Test 1D array indexing + subset_array = axs[np.array([0, 5])] + assert isinstance(subset_array, SubplotGrid) + assert len(subset_array) == 2 + assert subset_array[0].number == 1 + assert subset_array[1].number == 6 + + # Test 2D slicing (tuple of slices) + # axs[0:2, :] -> Rows 0 and 1, all cols + subset_2d = axs[0:2, :] + assert isinstance(subset_2d, SubplotGrid) + # 2 rows * 4 cols = 8 axes + assert len(subset_2d) == 8 + + # Test 2D mixed slicing (list in one dim) (Fix #2 related to _encode_indices) + # axs[[0, 1], :] -> Row indices 0 and 1, all cols + subset_mixed = axs[[0, 1], :] + assert isinstance(subset_mixed, SubplotGrid) + assert len(subset_mixed) == 8 + + # Verify content + # subset_mixed[0] -> Row 0, Col 0 -> Number 1 + # subset_mixed[4] -> Row 1, Col 0 -> Number 5 (since 4 cols per row) + assert subset_mixed[0].number == 1 + assert subset_mixed[4].number == 5 diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 6b984a55e..6efeff02c 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -529,3 +529,43 @@ def test_legend_explicit_handles_labels_override_auto_collection(): assert leg is not None assert len(leg.get_texts()) == 1 assert leg.get_texts()[0].get_text() == "custom_label" + + +import numpy as np + +import ultraplot as uplt + + +def test_legend_ref_argument(): + """Test using 'ref' to decouple legend location from content axes.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, 0].plot([], [], label="line1") # Row 0 + axs[1, 0].plot([], [], label="line2") # Row 1 + + # Place legend below Row 0 (axs[0, :]) using content from Row 1 (axs[1, :]) + leg = fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom") + + assert leg is not None + + # Should be a single legend because span is inferred from ref + assert not isinstance(leg, tuple) + + texts = [t.get_text() for t in leg.get_texts()] + assert "line2" in texts + assert "line1" not in texts + + +def test_legend_ref_argument_no_ax(): + """Test using 'ref' where 'ax' is implied to be 'ref'.""" + fig, axs = uplt.subplots(nrows=1, ncols=1) + axs[0].plot([], [], label="line1") + + # ref provided, ax=None. Should behave like ax=ref. + leg = fig.legend(ref=axs[0], loc="bottom") + assert leg is not None + + # Should be a single legend + assert not isinstance(leg, tuple) + + texts = [t.get_text() for t in leg.get_texts()] + assert "line1" in texts From c1f5a097fae3ac172229a4422c1a9d37c23d8241 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 13 Jan 2026 14:25:39 +1000 Subject: [PATCH 3/4] Add ref argument to fig.legend and fig.colorbar, support 1D slicing, intelligent placement, and robust checks --- ultraplot/figure.py | 113 ++++++++++++++++++++++++++++++--- ultraplot/tests/test_legend.py | 62 ++++++++++++++++-- 2 files changed, 161 insertions(+), 14 deletions(-) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 53d955835..ed7f1b6a1 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -2594,6 +2594,8 @@ def colorbar( """ # Backwards compatibility ax = kwargs.pop("ax", None) + ref = kwargs.pop("ref", None) + loc_ax = ref if ref is not None else ax cax = kwargs.pop("cax", None) if isinstance(values, maxes.Axes): cax = _not_none(cax_positional=values, cax=cax) @@ -2613,20 +2615,102 @@ def colorbar( with context._state_context(cax, _internal_call=True): # do not wrap pcolor cb = super().colorbar(mappable, cax=cax, **kwargs) # Axes panel colorbar - elif ax is not None: + elif loc_ax is not None: # Check if span parameters are provided has_span = _not_none(span, row, col, rows, cols) is not None + # Infer span from loc_ax if it is a list and no span provided + if ( + not has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + if side: + r_min, r_max = float("inf"), float("-inf") + c_min, c_max = float("inf"), float("-inf") + valid_ax = False + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + r_min = min(r_min, r1) + r_max = max(r_max, r2) + c_min = min(c_min, c1) + c_max = max(c_max, c2) + valid_ax = True + + if valid_ax: + if side in ("left", "right"): + rows = (r_min + 1, r_max + 1) + else: + cols = (c_min + 1, c_max + 1) + has_span = True + # Extract a single axes from array if span is provided # Otherwise, pass the array as-is for normal colorbar behavior - if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): - try: - ax_single = next(iter(ax)) + if ( + has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + # Pick the best axis to anchor to based on the colorbar side + loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + best_ax = None + best_coord = float("-inf") + + # If side is determined, search for the edge axis + if side: + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + + if side == "right": + val = c2 # Maximize column index + elif side == "left": + val = -c1 # Minimize column index + elif side == "bottom": + val = r2 # Maximize row index + elif side == "top": + val = -r1 # Minimize row index + else: + val = 0 - except (TypeError, StopIteration): - ax_single = ax + if val > best_coord: + best_coord = val + best_ax = axi + + # Fallback to first axis + if best_ax is None: + try: + ax_single = next(iter(loc_ax)) + except (TypeError, StopIteration): + ax_single = loc_ax + else: + ax_single = best_ax else: - ax_single = ax + ax_single = loc_ax # Pass span parameters through to axes colorbar cb = ax_single.colorbar( @@ -2751,7 +2835,10 @@ def legend( for axi in loc_ax: if not hasattr(axi, "get_subplotspec"): continue - ss = axi.get_subplotspec().get_topmost_subplotspec() + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() r1, r2, c1, c2 = ss._get_rows_columns() r_min = min(r_min, r1) r_max = max(r_max, r2) @@ -2789,7 +2876,10 @@ def legend( for axi in loc_ax: if not hasattr(axi, "get_subplotspec"): continue - ss = axi.get_subplotspec().get_topmost_subplotspec() + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() r1, r2, c1, c2 = ss._get_rows_columns() if side == "right": @@ -2818,6 +2908,11 @@ def legend( else: ax_single = loc_ax + if isinstance(ax_single, list): + try: + ax_single = pgridspec.SubplotGrid(ax_single) + except ValueError: + ax_single = ax_single[0] leg = ax_single.legend( handles, diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 6efeff02c..15684632d 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -531,11 +531,6 @@ def test_legend_explicit_handles_labels_override_auto_collection(): assert leg.get_texts()[0].get_text() == "custom_label" -import numpy as np - -import ultraplot as uplt - - def test_legend_ref_argument(): """Test using 'ref' to decouple legend location from content axes.""" fig, axs = uplt.subplots(nrows=2, ncols=2) @@ -569,3 +564,60 @@ def test_legend_ref_argument_no_ax(): texts = [t.get_text() for t in leg.get_texts()] assert "line1" in texts +import matplotlib.pyplot as plt +import pytest + +import ultraplot as uplt + + +def test_ref_with_explicit_handles(): + """Test using ref with explicit handles and labels.""" + fig, axs = uplt.subplots(ncols=2) + h = axs[0].plot([0, 1], [0, 1], label="line") + + # Place legend below both axes (ref=axs) using explicit handle + leg = fig.legend(handles=h, labels=["explicit"], ref=axs, loc="bottom") + + assert leg is not None + texts = [t.get_text() for t in leg.get_texts()] + assert texts == ["explicit"] + + +def test_ref_with_non_edge_location(): + """Test using ref with an inset location (should not infer span).""" + fig, axs = uplt.subplots(ncols=2) + axs[0].plot([0, 1], label="test") + + # ref=axs (list of 2). + # 'upper left' is inset. Should fallback to first axis. + leg = fig.legend(ref=axs, loc="upper left") + + assert leg is not None + if isinstance(leg, tuple): + leg = leg[0] + # Should be associated with axs[0] (or a panel of it? Inset is child of axes) + # leg.axes is the axes containing the legend. For inset, it's the parent axes? + # No, legend itself is an artist. leg.axes should be axs[0]. + assert leg.axes is axs[0] + + +def test_ref_with_single_axis(): + """Test using ref with a single axis object.""" + fig, axs = uplt.subplots(ncols=2) + axs[0].plot([0, 1], label="line") + + # ref=axs[1]. loc='bottom'. + leg = fig.legend(ref=axs[1], ax=axs[0], loc="bottom") + assert leg is not None + + +def test_ref_with_manual_axes_no_subplotspec(): + """Test using ref with axes that don't have subplotspec.""" + fig = uplt.figure() + ax1 = fig.add_axes([0.1, 0.1, 0.4, 0.4]) + ax2 = fig.add_axes([0.5, 0.1, 0.4, 0.4]) + ax1.plot([0, 1], [0, 1], label="line") + + # ref=[ax1, ax2]. loc='upper right' (inset). + leg = fig.legend(ref=[ax1, ax2], loc="upper right") + assert leg is not None From 572adf3a7ad5e4ace97a32046f39853a4b0085ce Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 13 Jan 2026 15:24:49 +1000 Subject: [PATCH 4/4] Remove xdist from image compare --- .github/workflows/build-ultraplot.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-ultraplot.yml b/.github/workflows/build-ultraplot.yml index 158f4bbb1..5a130c816 100644 --- a/.github/workflows/build-ultraplot.yml +++ b/.github/workflows/build-ultraplot.yml @@ -97,7 +97,7 @@ jobs: # Generate the baseline images and hash library python -c "import ultraplot as plt; plt.config.Configurator()._save_yaml('ultraplot.yml')" - pytest -x -n auto -W ignore \ + pytest -x -W ignore \ --mpl-generate-path=./ultraplot/tests/baseline/ \ --mpl-default-style="./ultraplot.yml"\ ultraplot/tests @@ -113,7 +113,7 @@ jobs: mkdir -p results python -c "import ultraplot as plt; plt.config.Configurator()._save_yaml('ultraplot.yml')" - pytest -x -n auto -W ignore -n auto\ + pytest -x -W ignore -n auto\ --mpl \ --mpl-baseline-path=./ultraplot/tests/baseline \ --mpl-results-path=./results/ \