diff --git a/RATapi/utils/plotting.py b/RATapi/utils/plotting.py index 92322c98..e6d925a3 100644 --- a/RATapi/utils/plotting.py +++ b/RATapi/utils/plotting.py @@ -1,14 +1,15 @@ """Plot results using the matplotlib library.""" import copy +import types from functools import partial, wraps from math import ceil, floor, sqrt from statistics import stdev -from textwrap import fill from typing import Callable, Literal, Optional, Union import matplotlib import matplotlib.pyplot as plt +import matplotlib.transforms as mtransforms import numpy as np from matplotlib.axes._axes import Axes from scipy.ndimage import gaussian_filter1d @@ -668,11 +669,15 @@ def plot_corner( num_params = len(params) - fig, axes = plt.subplots(num_params, num_params, figsize=(2 * num_params, 2 * num_params)) + fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10)) # i is row, j is column for i, row_param in enumerate(params): for j, col_param in enumerate(params): current_axes: Axes = axes[i][j] + current_axes.tick_params(which="both", labelsize="medium") + current_axes.xaxis.offsetText.set_fontsize("small") + current_axes.yaxis.offsetText.set_fontsize("small") + if i == j: # diagonal: histograms plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs) elif i > j: # lower triangle: 2d histograms @@ -687,10 +692,12 @@ def plot_corner( if i != len(params) - 1: current_axes.get_xaxis().set_visible(False) # make labels invisible as titles cover that + current_axes.yaxis._update_offset_text_position = types.MethodType( + _y_update_offset_text_position, current_axes.yaxis + ) + current_axes.yaxis.offset_text_position = "center" current_axes.set_ylabel("") current_axes.set_xlabel("") - - fig.tight_layout() if return_fig: return fig plt.show(block=block) @@ -776,7 +783,7 @@ def plot_one_hist( color="white", ) - axes.set_title(fill(results.fitNames[param], 20)) # use `fill` to wrap long titles + axes.set_title(results.fitNames[param], loc="left", fontsize="medium") if estimated_density: dx = bins[1] - bins[0] @@ -806,6 +813,47 @@ def plot_one_hist( plt.show(block=block) +def _y_update_offset_text_position(axis, _bboxes, bboxes2): + """Update the position of the Y axis offset text using the provided bounding boxes. + + Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334. + + Parameters + ---------- + axis : matplotlib.axis.YAxis + Y axis to update. + _bboxes : List + list of bounding boxes + bboxes2 : List + list of bounding boxes + """ + x, y = axis.offsetText.get_position() + + if axis.offset_text_position == "left": + # y in axes coords, x in display coords + axis.offsetText.set_transform( + mtransforms.blended_transform_factory(axis.axes.transAxes, mtransforms.IdentityTransform()) + ) + + top = axis.axes.bbox.ymax + y = top + axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0 + + else: + # x & y in display coords + axis.offsetText.set_transform(mtransforms.IdentityTransform()) + + # Northwest of upper-right corner of right-hand extent of tick labels + if bboxes2: + bbox = mtransforms.Bbox.union(bboxes2) + else: + bbox = axis.axes.bbox + center = bbox.ymin + (bbox.ymax - bbox.ymin) / 2 + x = bbox.xmin - axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0 + y = center + x_offset = 110 + axis.offsetText.set_position((x - x_offset, y)) + + @assert_bayesian("Contour") def plot_contour( results: RATapi.outputs.BayesResults, @@ -899,7 +947,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig """ nplots = len(indices) nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots)) - fig = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2 * nrows))[0] + fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0] axs = fig.get_axes() for plot_num, index in enumerate(indices): diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 4cdcb0aa..abeb4177 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1,7 +1,6 @@ import os import pickle from math import ceil, sqrt -from textwrap import fill from unittest.mock import MagicMock, patch import matplotlib.pyplot as plt @@ -293,7 +292,7 @@ def test_hist(dream_results, param, hist_settings, est_dens): # assert title is as expected # also tests string to index conversion - assert ax.get_title() == fill(dream_results.fitNames[param] if isinstance(param, int) else param, 20) + assert ax.get_title(loc="left") == dream_results.fitNames[param] if isinstance(param, int) else param # assert range is default, unless given # this tests non-default hist_settings propagates correctly @@ -377,8 +376,10 @@ def test_corner(dream_results, params): assert current_axes.get_xbound() == axes[-1][j].get_xbound() elif i == j: # check title is correct - assert current_axes.get_title() == fill( - dream_results.fitNames[params[i]] if isinstance(params[i], int) else params[i], 20 + assert ( + current_axes.get_title(loc="left") == dream_results.fitNames[params[i]] + if isinstance(params[i], int) + else params[i] ) plt.close(fig)