Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 15 additions & 191 deletions ratapi/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ratapi.rat_core import PlotEventData, makeSLDProfile


def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool):
def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool, shift_value: float):
"""Extract the plot data for the sld, ref, error plot lines.

Parameters
Expand All @@ -33,6 +33,8 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
Controls whether Q^4 is plotted on the reflectivity plot
show_error_bar : bool, default: True
Controls whether the error bars are shown
shift_value : float
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts

Returns
-------
Expand All @@ -42,9 +44,12 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
"""
results = {"ref": [], "error": [], "sld": [], "sld_resample": []}

if shift_value < 1 or shift_value > 100:
raise ValueError("Parameter `shift_value` must be between 1 and 100")

for i, (r, data, sld) in enumerate(zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles)):
# Calculate the divisor
div = 1 if i == 0 and not q4 else 2 ** (4 * (i + 1))
div = 1 if i == 0 and not q4 else 10 ** ((i / 100) * shift_value)
q4_data = 1 if not q4 or not event_data.dataPresent[i] else data[:, 0] ** 4
mult = q4_data / div

Expand Down Expand Up @@ -87,194 +92,6 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool
return results


class PlotSLDWithBlitting:
"""Create a SLD plot that uses blitting to get faster draws.

The blit plot stores the background from an
initial draw then updates the foreground (lines and error bars) if the background is not changed.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
fig : matplotlib.pyplot.figure, optional
The figure class that has two subplots
linear_x : bool, default: False
Controls whether the x-axis on reflectivity plot uses the linear scale
q4 : bool, default: False
Controls whether Q^4 is plotted on the reflectivity plot
show_error_bar : bool, default: True
Controls whether the error bars are shown
show_grid : bool, default: False
Controls whether the grid is shown
show_legend : bool, default: True
Controls whether the legend is shown
"""

def __init__(
self,
data: PlotEventData,
fig: Optional[matplotlib.pyplot.figure] = None,
linear_x: bool = False,
q4: bool = False,
show_error_bar: bool = True,
show_grid: bool = False,
show_legend: bool = True,
):
self.figure = fig
self.linear_x = linear_x
self.q4 = q4
self.show_error_bar = show_error_bar
self.show_grid = show_grid
self.show_legend = show_legend
self.updatePlot(data)
self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent)

def __del__(self):
self.figure.canvas.mpl_disconnect(self.event_id)

def resizeEvent(self, _event):
"""Ensure the background is updated after a resize event."""
self.__background_changed = True

def update(self, data: PlotEventData):
"""Update the foreground, if background has not changed otherwise it updates full plot.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
"""
if self.__background_changed:
self.updatePlot(data)
else:
self.updateForeground(data)

def __setattr__(self, name, value):
super().__setattr__(name, value)
if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend"]:
self.__background_changed = True

def setAnimated(self, is_animated: bool):
"""Set the animated property of foreground plot elements.

Parameters
----------
is_animated : bool
Indicates if the animated property should been set.
"""
for line in self.figure.axes[0].lines:
line.set_animated(is_animated)
for line in self.figure.axes[1].lines:
line.set_animated(is_animated)
for container in self.figure.axes[0].containers:
container[2][0].set_animated(is_animated)

def adjustErrorBar(self, error_bar_container, x, y, y_error):
"""Adjust the error bar data.

Parameters
----------
error_bar_container : Tuple
Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines)
x : np.ndarray
The shifted data x axis data
y : np.ndarray
The shifted data y axis data
y_error : np.ndarray
The shifted data y axis error data
"""
line, _, (bars_y,) = error_bar_container

line.set_data(x, y)
x_base = x
y_base = y

y_error_top = y_base + y_error
y_error_bottom = y_base - y_error

new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)]
bars_y.set_segments(new_segments_y)

def updatePlot(self, data: PlotEventData):
"""Update the full plot.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
"""
if self.figure is not None:
self.figure.clf()
self.figure = plot_ref_sld_helper(
data,
self.figure,
linear_x=self.linear_x,
q4=self.q4,
show_error_bar=self.show_error_bar,
show_grid=self.show_grid,
show_legend=self.show_legend,
animated=True,
)

self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox)
for line in self.figure.axes[0].lines:
self.figure.axes[0].draw_artist(line)
for line in self.figure.axes[1].lines:
self.figure.axes[1].draw_artist(line)
for container in self.figure.axes[0].containers:
self.figure.axes[0].draw_artist(container[2][0])
self.figure.canvas.blit(self.figure.bbox)
self.setAnimated(False)
self.__background_changed = False

def updateForeground(self, data: PlotEventData):
"""Update the plot foreground only.

Parameters
----------
data : PlotEventData
The plot event data that contains all the information
to generate the ref and sld plots
"""
self.setAnimated(True)
self.figure.canvas.restore_region(self.bg)
plot_data = _extract_plot_data(data, self.q4, self.show_error_bar)

offset = 2 if self.show_error_bar else 1
for i in range(
0,
len(self.figure.axes[0].lines),
):
self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1])
self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i])

i = 0
for j in range(len(plot_data["sld"])):
for sld in plot_data["sld"][j]:
self.figure.axes[1].lines[i].set_data(sld[0], sld[1])
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
i += 1

if plot_data["sld_resample"]:
for resampled in plot_data["sld_resample"][j]:
self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1])
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
i += 1

for i, container in enumerate(self.figure.axes[0].containers):
self.adjustErrorBar(container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2])
self.figure.axes[0].draw_artist(container[2][0])
self.figure.axes[0].draw_artist(container[0])

self.figure.canvas.blit(self.figure.bbox)
self.figure.canvas.flush_events()
self.setAnimated(False)


def plot_ref_sld_helper(
data: PlotEventData,
fig: Optional[matplotlib.pyplot.figure] = None,
Expand All @@ -285,6 +102,7 @@ def plot_ref_sld_helper(
show_error_bar: bool = True,
show_grid: bool = False,
show_legend: bool = True,
shift_value: float = 100,
animated=False,
):
"""Clear the previous plots and updates the ref and SLD plots.
Expand All @@ -311,6 +129,8 @@ def plot_ref_sld_helper(
Controls whether the grid is shown
show_legend : bool, default: True
Controls whether the legend is shown
shift_value : float, default: 100
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts
animated : bool, default: False
Controls whether the animated property of foreground plot elements should be set.

Expand Down Expand Up @@ -339,7 +159,7 @@ def plot_ref_sld_helper(
ref_plot.cla()
sld_plot.cla()

plot_data = _extract_plot_data(data, q4, show_error_bar)
plot_data = _extract_plot_data(data, q4, show_error_bar, shift_value)
for i, name in enumerate(data.contrastNames):
ref_plot.plot(plot_data["ref"][i][0], plot_data["ref"][i][1], label=name, linewidth=1, animated=animated)
color = ref_plot.get_lines()[-1].get_color()
Expand Down Expand Up @@ -427,6 +247,7 @@ def plot_ref_sld(
show_error_bar: bool = True,
show_grid: bool = False,
show_legend: bool = True,
shift_value: float = 100,
) -> Union[plt.Figure, None]:
"""Plot the reflectivity and SLD profiles.

Expand Down Expand Up @@ -454,6 +275,8 @@ def plot_ref_sld(
Controls whether the grid is shown
show_legend : bool, default: True
Controls whether the legend is shown
shift_value : float, default: 100
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts

Returns
-------
Expand Down Expand Up @@ -524,6 +347,7 @@ def plot_ref_sld(
show_error_bar=show_error_bar,
show_grid=show_grid,
show_legend=show_legend,
shift_value=shift_value,
)

if return_fig:
Expand Down
18 changes: 5 additions & 13 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,20 +481,12 @@ def test_bayes_validation(input_project, reflectivity_calculation_results):

@pytest.mark.parametrize("data", [data(), domains_data()])
def test_extract_plot_data(data) -> None:
plot_data = RATplot._extract_plot_data(data, False, True)
plot_data = RATplot._extract_plot_data(data, False, True, 50)
assert len(plot_data["ref"]) == len(data.reflectivity)
assert len(plot_data["sld"]) == len(data.shiftedData)

with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 1 and 100"):
RATplot._extract_plot_data(data, False, True, 0)

@patch("ratapi.utils.plotting.plot_ref_sld_helper")
def test_blit_plot(plot_helper, fig: plt.figure) -> None:
plot_helper.return_value = fig
event_data = data()
new_plot = RATplot.PlotSLDWithBlitting(event_data)
assert plot_helper.call_count == 1
new_plot.update(event_data)
assert plot_helper.call_count == 1 # foreground only is updated so no call to plot helper
new_plot.show_grid = False
new_plot.figure = plt.subplots(1, 2)[0]
new_plot.update(event_data) # plot properties have changed so update should call plot_helper
assert plot_helper.call_count == 2
with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 1 and 100"):
RATplot._extract_plot_data(data, False, True, 100.5)
Loading