Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
env:
CIBW_SKIP: 'pp*'
CIBW_ARCHS: 'auto64'
CIBW_MANYLINUX_X86_64_IMAGE: 'manylinux_2_28'
CIBW_PROJECT_REQUIRES_PYTHON: '>=3.10'
CIBW_TEST_REQUIRES: 'pytest'
defaults:
Expand Down
41 changes: 27 additions & 14 deletions ratapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,10 @@ def plot_contour(


def panel_plot_helper(
plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None
plot_func: Callable,
indices: list[int],
fig: matplotlib.figure.Figure | None = None,
progress_callback: Callable[[int, int], None] | None = None,
) -> matplotlib.figure.Figure:
"""Generate a panel-based plot from a single plot function.

Expand All @@ -994,6 +997,9 @@ def panel_plot_helper(
The list of indices to pass into ``plot_func``.
fig : matplotlib.figure.Figure, optional
The figure object to use for plot.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots

Returns
-------
Expand All @@ -1005,21 +1011,19 @@ def panel_plot_helper(
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))

if fig is None:
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
fig = plt.subplots(nrows, ncols, figsize=(11, 10), subplot_kw={"visible": False})[0]
else:
fig.clf()
fig.subplots(nrows, ncols)
fig.subplots(nrows, ncols, subplot_kw={"visible": False})
axs = fig.get_axes()

for plot_num, index in enumerate(indices):
axs[plot_num].tick_params(which="both", labelsize="medium")
axs[plot_num].xaxis.offsetText.set_fontsize("small")
axs[plot_num].yaxis.offsetText.set_fontsize("small")
plot_func(axs[plot_num], index)

# blank unused plots
for i in range(nplots, len(axs)):
axs[i].set_visible(False)
for index, plot_num in enumerate(indices):
axs[index].tick_params(which="both", labelsize="medium")
axs[index].xaxis.offsetText.set_fontsize("small")
axs[index].yaxis.offsetText.set_fontsize("small")
axs[index].set_visible(True)
plot_func(axs[index], plot_num)
if progress_callback is not None:
progress_callback(index, nplots)

fig.tight_layout()
return fig
Expand All @@ -1036,6 +1040,7 @@ def plot_hists(
block: bool = False,
fig: matplotlib.figure.Figure | None = None,
return_fig: bool = False,
progress_callback: Callable[[int, int], None] | None = None,
**hist_settings,
):
"""Plot marginalised posteriors for several parameters from a Bayesian analysis.
Expand Down Expand Up @@ -1072,6 +1077,9 @@ def plot_hists(
The figure object to use for plot.
return_fig: bool, default False
If True, return the figure as an object instead of showing it.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots
hist_settings :
Settings passed to `np.histogram`. By default, the settings
passed are `bins = 25` and `density = True`.
Expand Down Expand Up @@ -1130,6 +1138,7 @@ def validate_dens_type(dens_type: str | None, param: str):
),
params,
fig,
progress_callback,
)
if return_fig:
return fig
Expand All @@ -1144,6 +1153,7 @@ def plot_chain(
block: bool = False,
fig: matplotlib.figure.Figure | None = None,
return_fig: bool = False,
progress_callback: Callable[[int, int], None] | None = None,
):
"""Plot the MCMC chain for each parameter of a Bayesian analysis.

Expand All @@ -1162,6 +1172,9 @@ def plot_chain(
The figure object to use for plot.
return_fig: bool, default False
If True, return the figure as an object instead of showing it.
progress_callback: Union[Callable[[int, int], None], None]
Callback function for providing progress during plot creation
First argument is current completed sub plot and second is total number of sub plots

Returns
-------
Expand All @@ -1187,7 +1200,7 @@ def plot_one_chain(axes: Axes, i: int):
axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip])
axes.set_title(results.fitNames[i], fontsize="small")

fig = panel_plot_helper(plot_one_chain, params, fig=fig)
fig = panel_plot_helper(plot_one_chain, params, fig, progress_callback)
if return_fig:
return fig
plt.show(block=block)
Expand Down