diff --git a/docs/colorbars_legends.py b/docs/colorbars_legends.py index 10a4099c8..8e8002975 100644 --- a/docs/colorbars_legends.py +++ b/docs/colorbars_legends.py @@ -469,3 +469,44 @@ ax = axs[1] ax.legend(hs2, loc="b", ncols=3, center=True, title="centered rows") axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Legend formatting demo") +# %% [raw] raw_mimetype="text/restructuredtext" +# .. _ug_guides_decouple: +# +# Decoupling legend content and location +# -------------------------------------- +# +# Sometimes you may want to generate a legend using handles from specific axes +# but place it relative to other axes. In UltraPlot, you can achieve this by passing +# both the `ax` and `ref` keywords to :func:`~ultraplot.figure.Figure.legend` +# (or :func:`~ultraplot.figure.Figure.colorbar`). The `ax` keyword specifies the +# axes used to generate the legend handles, while the `ref` keyword specifies the +# reference axes used to determine the legend location. +# +# For example, to draw a legend based on the handles in the second row of subplots +# but place it below the first row of subplots, you can use +# ``fig.legend(ax=axs[1, :], ref=axs[0, :], loc='bottom')``. If ``ref`` is a list +# of axes, UltraPlot intelligently infers the span (width or height) and anchors +# the legend to the appropriate outer edge (e.g., the bottom-most axis for ``loc='bottom'`` +# or the right-most axis for ``loc='right'``). + +# %% +import numpy as np + +import ultraplot as uplt + +fig, axs = uplt.subplots(nrows=2, ncols=2, refwidth=2, share=False) +axs.format(abc="A.", suptitle="Decoupled legend location demo") + +# Plot data on all axes +state = np.random.RandomState(51423) +data = (state.rand(20, 4) - 0.5).cumsum(axis=0) +for ax in axs: + ax.plot(data, cycle="mplotcolors", labels=list("abcd")) + +# Legend 1: Content from Row 2 (ax=axs[1, :]), Location below Row 1 (ref=axs[0, :]) +# This places a legend describing the bottom row data underneath the top row. +fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom", title="Data from Row 2") + +# Legend 2: Content from Row 1 (ax=axs[0, :]), Location below Row 2 (ref=axs[1, :]) +# This places a legend describing the top row data underneath the bottom row. +fig.legend(ax=axs[0, :], ref=axs[1, :], loc="bottom", title="Data from Row 1") diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 5d302f318..ed7f1b6a1 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -2594,6 +2594,8 @@ def colorbar( """ # Backwards compatibility ax = kwargs.pop("ax", None) + ref = kwargs.pop("ref", None) + loc_ax = ref if ref is not None else ax cax = kwargs.pop("cax", None) if isinstance(values, maxes.Axes): cax = _not_none(cax_positional=values, cax=cax) @@ -2613,20 +2615,102 @@ def colorbar( with context._state_context(cax, _internal_call=True): # do not wrap pcolor cb = super().colorbar(mappable, cax=cax, **kwargs) # Axes panel colorbar - elif ax is not None: + elif loc_ax is not None: # Check if span parameters are provided has_span = _not_none(span, row, col, rows, cols) is not None + # Infer span from loc_ax if it is a list and no span provided + if ( + not has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + if side: + r_min, r_max = float("inf"), float("-inf") + c_min, c_max = float("inf"), float("-inf") + valid_ax = False + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + r_min = min(r_min, r1) + r_max = max(r_max, r2) + c_min = min(c_min, c1) + c_max = max(c_max, c2) + valid_ax = True + + if valid_ax: + if side in ("left", "right"): + rows = (r_min + 1, r_max + 1) + else: + cols = (c_min + 1, c_max + 1) + has_span = True + # Extract a single axes from array if span is provided # Otherwise, pass the array as-is for normal colorbar behavior - if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): - try: - ax_single = next(iter(ax)) + if ( + has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + # Pick the best axis to anchor to based on the colorbar side + loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) - except (TypeError, StopIteration): - ax_single = ax + best_ax = None + best_coord = float("-inf") + + # If side is determined, search for the edge axis + if side: + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + + if side == "right": + val = c2 # Maximize column index + elif side == "left": + val = -c1 # Minimize column index + elif side == "bottom": + val = r2 # Maximize row index + elif side == "top": + val = -r1 # Minimize row index + else: + val = 0 + + if val > best_coord: + best_coord = val + best_ax = axi + + # Fallback to first axis + if best_ax is None: + try: + ax_single = next(iter(loc_ax)) + except (TypeError, StopIteration): + ax_single = loc_ax + else: + ax_single = best_ax else: - ax_single = ax + ax_single = loc_ax # Pass span parameters through to axes colorbar cb = ax_single.colorbar( @@ -2700,27 +2784,136 @@ def legend( matplotlib.axes.Axes.legend """ ax = kwargs.pop("ax", None) + ref = kwargs.pop("ref", None) + loc_ax = ref if ref is not None else ax + # Axes panel legend - if ax is not None: + if loc_ax is not None: + content_ax = ax if ax is not None else loc_ax # Check if span parameters are provided has_span = _not_none(span, row, col, rows, cols) is not None - # Extract a single axes from array if span is provided - # Otherwise, pass the array as-is for normal legend behavior - # Automatically collect handles and labels from spanned axes if not provided - if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): - # Auto-collect handles and labels if not explicitly provided - if handles is None and labels is None: - handles, labels = [], [] - for axi in ax: + + # Automatically collect handles and labels from content axes if not provided + # Case 1: content_ax is a list (we must auto-collect) + # Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles) + must_collect = ( + np.iterable(content_ax) + and not isinstance(content_ax, (str, maxes.Axes)) + ) or (content_ax is not loc_ax) + + if must_collect and handles is None and labels is None: + handles, labels = [], [] + # Handle list of axes + if np.iterable(content_ax) and not isinstance( + content_ax, (str, maxes.Axes) + ): + for axi in content_ax: h, l = axi.get_legend_handles_labels() handles.extend(h) labels.extend(l) - try: - ax_single = next(iter(ax)) - except (TypeError, StopIteration): - ax_single = ax + # Handle single axis + else: + handles, labels = content_ax.get_legend_handles_labels() + + # Infer span from loc_ax if it is a list and no span provided + if ( + not has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + if side: + r_min, r_max = float("inf"), float("-inf") + c_min, c_max = float("inf"), float("-inf") + valid_ax = False + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + r_min = min(r_min, r1) + r_max = max(r_max, r2) + c_min = min(c_min, c1) + c_max = max(c_max, c2) + valid_ax = True + + if valid_ax: + if side in ("left", "right"): + rows = (r_min + 1, r_max + 1) + else: + cols = (c_min + 1, c_max + 1) + has_span = True + + # Extract a single axes from array if span is provided (or if ref is a list) + # Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list) + if ( + has_span + and np.iterable(loc_ax) + and not isinstance(loc_ax, (str, maxes.Axes)) + ): + # Pick the best axis to anchor to based on the legend side + loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"]) + side = ( + loc_trans + if loc_trans in ("left", "right", "top", "bottom") + else None + ) + + best_ax = None + best_coord = float("-inf") + + # If side is determined, search for the edge axis + if side: + for axi in loc_ax: + if not hasattr(axi, "get_subplotspec"): + continue + ss = axi.get_subplotspec() + if ss is None: + continue + ss = ss.get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + + if side == "right": + val = c2 # Maximize column index + elif side == "left": + val = -c1 # Minimize column index + elif side == "bottom": + val = r2 # Maximize row index + elif side == "top": + val = -r1 # Minimize row index + else: + val = 0 + + if val > best_coord: + best_coord = val + best_ax = axi + + # Fallback to first axis if no best axis found (or side is None) + if best_ax is None: + try: + ax_single = next(iter(loc_ax)) + except (TypeError, StopIteration): + ax_single = loc_ax + else: + ax_single = best_ax + else: - ax_single = ax + ax_single = loc_ax + if isinstance(ax_single, list): + try: + ax_single = pgridspec.SubplotGrid(ax_single) + except ValueError: + ax_single = ax_single[0] + leg = ax_single.legend( handles, labels, diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 288f1abc4..93a6343a5 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -425,6 +425,12 @@ def _encode_indices(self, *args, which=None, panel=False): nums = [] idxs = self._get_indices(which=which, panel=panel) for arg in args: + if isinstance(arg, (list, np.ndarray)): + try: + nums.append([idxs[int(i)] for i in arg]) + except (IndexError, TypeError): + raise ValueError(f"Invalid gridspec index {arg}.") + continue try: nums.append(idxs[arg]) except (IndexError, TypeError): @@ -1612,10 +1618,13 @@ def __getitem__(self, key): >>> axs[:, 0] # a SubplotGrid containing the subplots in the first column """ # Allow 1D list-like indexing - if isinstance(key, int): + if isinstance(key, (Integral, np.integer)): return list.__getitem__(self, key) elif isinstance(key, slice): return SubplotGrid(list.__getitem__(self, key)) + elif isinstance(key, (list, np.ndarray)): + # NOTE: list.__getitem__ does not support numpy integers + return SubplotGrid([list.__getitem__(self, int(i)) for i in key]) # Allow 2D array-like indexing # NOTE: We assume this is a 2D array of subplots, because this is diff --git a/ultraplot/tests/test_gridspec.py b/ultraplot/tests/test_gridspec.py index e3890d7a3..b676f36a9 100644 --- a/ultraplot/tests/test_gridspec.py +++ b/ultraplot/tests/test_gridspec.py @@ -1,5 +1,6 @@ -import ultraplot as uplt import pytest + +import ultraplot as uplt from ultraplot.gridspec import SubplotGrid @@ -72,3 +73,56 @@ def test_tight_layout_disabled(): gs = ax.get_subplotspec().get_gridspec() with pytest.raises(RuntimeError): gs.tight_layout(fig) + + +def test_gridspec_slicing(): + """ + Test various slicing methods on SubplotGrid, including 1D list/array indexing. + """ + import numpy as np + + fig, axs = uplt.subplots(nrows=4, ncols=4) + + # Test 1D integer indexing + assert axs[0].number == 1 + assert axs[15].number == 16 + + # Test 1D slice indexing + subset = axs[0:2] + assert isinstance(subset, SubplotGrid) + assert len(subset) == 2 + assert subset[0].number == 1 + assert subset[1].number == 2 + + # Test 1D list indexing (Fix #1) + subset_list = axs[[0, 5]] + assert isinstance(subset_list, SubplotGrid) + assert len(subset_list) == 2 + assert subset_list[0].number == 1 + assert subset_list[1].number == 6 + + # Test 1D array indexing + subset_array = axs[np.array([0, 5])] + assert isinstance(subset_array, SubplotGrid) + assert len(subset_array) == 2 + assert subset_array[0].number == 1 + assert subset_array[1].number == 6 + + # Test 2D slicing (tuple of slices) + # axs[0:2, :] -> Rows 0 and 1, all cols + subset_2d = axs[0:2, :] + assert isinstance(subset_2d, SubplotGrid) + # 2 rows * 4 cols = 8 axes + assert len(subset_2d) == 8 + + # Test 2D mixed slicing (list in one dim) (Fix #2 related to _encode_indices) + # axs[[0, 1], :] -> Row indices 0 and 1, all cols + subset_mixed = axs[[0, 1], :] + assert isinstance(subset_mixed, SubplotGrid) + assert len(subset_mixed) == 8 + + # Verify content + # subset_mixed[0] -> Row 0, Col 0 -> Number 1 + # subset_mixed[4] -> Row 1, Col 0 -> Number 5 (since 4 cols per row) + assert subset_mixed[0].number == 1 + assert subset_mixed[4].number == 5 diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index 6b984a55e..a37f2ff0a 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -529,3 +529,91 @@ def test_legend_explicit_handles_labels_override_auto_collection(): assert leg is not None assert len(leg.get_texts()) == 1 assert leg.get_texts()[0].get_text() == "custom_label" + + +def test_legend_ref_argument(): + """Test using 'ref' to decouple legend location from content axes.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs[0, 0].plot([], [], label="line1") # Row 0 + axs[1, 0].plot([], [], label="line2") # Row 1 + + # Place legend below Row 0 (axs[0, :]) using content from Row 1 (axs[1, :]) + leg = fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom") + + assert leg is not None + + # Should be a single legend because span is inferred from ref + assert not isinstance(leg, tuple) + + texts = [t.get_text() for t in leg.get_texts()] + assert "line2" in texts + assert "line1" not in texts + + +def test_legend_ref_argument_no_ax(): + """Test using 'ref' where 'ax' is implied to be 'ref'.""" + fig, axs = uplt.subplots(nrows=1, ncols=1) + axs[0].plot([], [], label="line1") + + # ref provided, ax=None. Should behave like ax=ref. + leg = fig.legend(ref=axs[0], loc="bottom") + assert leg is not None + + # Should be a single legend + assert not isinstance(leg, tuple) + + texts = [t.get_text() for t in leg.get_texts()] + assert "line1" in texts + + +def test_ref_with_explicit_handles(): + """Test using ref with explicit handles and labels.""" + fig, axs = uplt.subplots(ncols=2) + h = axs[0].plot([0, 1], [0, 1], label="line") + + # Place legend below both axes (ref=axs) using explicit handle + leg = fig.legend(handles=h, labels=["explicit"], ref=axs, loc="bottom") + + assert leg is not None + texts = [t.get_text() for t in leg.get_texts()] + assert texts == ["explicit"] + + +def test_ref_with_non_edge_location(): + """Test using ref with an inset location (should not infer span).""" + fig, axs = uplt.subplots(ncols=2) + axs[0].plot([0, 1], label="test") + + # ref=axs (list of 2). + # 'upper left' is inset. Should fallback to first axis. + leg = fig.legend(ref=axs, loc="upper left") + + assert leg is not None + if isinstance(leg, tuple): + leg = leg[0] + # Should be associated with axs[0] (or a panel of it? Inset is child of axes) + # leg.axes is the axes containing the legend. For inset, it's the parent axes? + # No, legend itself is an artist. leg.axes should be axs[0]. + assert leg.axes is axs[0] + + +def test_ref_with_single_axis(): + """Test using ref with a single axis object.""" + fig, axs = uplt.subplots(ncols=2) + axs[0].plot([0, 1], label="line") + + # ref=axs[1]. loc='bottom'. + leg = fig.legend(ref=axs[1], ax=axs[0], loc="bottom") + assert leg is not None + + +def test_ref_with_manual_axes_no_subplotspec(): + """Test using ref with axes that don't have subplotspec.""" + fig = uplt.figure() + ax1 = fig.add_axes([0.1, 0.1, 0.4, 0.4]) + ax2 = fig.add_axes([0.5, 0.1, 0.4, 0.4]) + ax1.plot([0, 1], [0, 1], label="line") + + # ref=[ax1, ax2]. loc='upper right' (inset). + leg = fig.legend(ref=[ax1, ax2], loc="upper right") + assert leg is not None