Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions RATapi/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Plot results using the matplotlib library."""

import copy
import types
from functools import partial, wraps
from math import ceil, floor, sqrt
from statistics import stdev
from textwrap import fill
from typing import Callable, Literal, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib.axes._axes import Axes
from scipy.ndimage import gaussian_filter1d
Expand Down Expand Up @@ -668,11 +669,15 @@ def plot_corner(

num_params = len(params)

fig, axes = plt.subplots(num_params, num_params, figsize=(2 * num_params, 2 * num_params))
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10))
# i is row, j is column
for i, row_param in enumerate(params):
for j, col_param in enumerate(params):
current_axes: Axes = axes[i][j]
current_axes.tick_params(which="both", labelsize="medium")
current_axes.xaxis.offsetText.set_fontsize("small")
current_axes.yaxis.offsetText.set_fontsize("small")

if i == j: # diagonal: histograms
plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs)
elif i > j: # lower triangle: 2d histograms
Expand All @@ -687,10 +692,12 @@ def plot_corner(
if i != len(params) - 1:
current_axes.get_xaxis().set_visible(False)
# make labels invisible as titles cover that
current_axes.yaxis._update_offset_text_position = types.MethodType(
_y_update_offset_text_position, current_axes.yaxis
)
current_axes.yaxis.offset_text_position = "center"
current_axes.set_ylabel("")
current_axes.set_xlabel("")

fig.tight_layout()
if return_fig:
return fig
plt.show(block=block)
Expand Down Expand Up @@ -776,7 +783,7 @@ def plot_one_hist(
color="white",
)

axes.set_title(fill(results.fitNames[param], 20)) # use `fill` to wrap long titles
axes.set_title(results.fitNames[param], loc="left", fontsize="medium")

if estimated_density:
dx = bins[1] - bins[0]
Expand Down Expand Up @@ -806,6 +813,47 @@ def plot_one_hist(
plt.show(block=block)


def _y_update_offset_text_position(axis, _bboxes, bboxes2):
"""Update the position of the Y axis offset text using the provided bounding boxes.

Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334.

Parameters
----------
axis : matplotlib.axis.YAxis
Y axis to update.
_bboxes : List
list of bounding boxes
bboxes2 : List
list of bounding boxes
"""
x, y = axis.offsetText.get_position()

if axis.offset_text_position == "left":
# y in axes coords, x in display coords
axis.offsetText.set_transform(
mtransforms.blended_transform_factory(axis.axes.transAxes, mtransforms.IdentityTransform())
)

top = axis.axes.bbox.ymax
y = top + axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0

else:
# x & y in display coords
axis.offsetText.set_transform(mtransforms.IdentityTransform())

# Northwest of upper-right corner of right-hand extent of tick labels
if bboxes2:
bbox = mtransforms.Bbox.union(bboxes2)
else:
bbox = axis.axes.bbox
center = bbox.ymin + (bbox.ymax - bbox.ymin) / 2
x = bbox.xmin - axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0
y = center
x_offset = 110
axis.offsetText.set_position((x - x_offset, y))


@assert_bayesian("Contour")
def plot_contour(
results: RATapi.outputs.BayesResults,
Expand Down Expand Up @@ -899,7 +947,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
"""
nplots = len(indices)
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
fig = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2 * nrows))[0]
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
axs = fig.get_axes()

for plot_num, index in enumerate(indices):
Expand Down
9 changes: 5 additions & 4 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import pickle
from math import ceil, sqrt
from textwrap import fill
from unittest.mock import MagicMock, patch

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -293,7 +292,7 @@ def test_hist(dream_results, param, hist_settings, est_dens):

# assert title is as expected
# also tests string to index conversion
assert ax.get_title() == fill(dream_results.fitNames[param] if isinstance(param, int) else param, 20)
assert ax.get_title(loc="left") == dream_results.fitNames[param] if isinstance(param, int) else param

# assert range is default, unless given
# this tests non-default hist_settings propagates correctly
Expand Down Expand Up @@ -377,8 +376,10 @@ def test_corner(dream_results, params):
assert current_axes.get_xbound() == axes[-1][j].get_xbound()
elif i == j:
# check title is correct
assert current_axes.get_title() == fill(
dream_results.fitNames[params[i]] if isinstance(params[i], int) else params[i], 20
assert (
current_axes.get_title(loc="left") == dream_results.fitNames[params[i]]
if isinstance(params[i], int)
else params[i]
)

plt.close(fig)
Expand Down