From 6e2814d0f8897e41640e58b810c9304c74154018 Mon Sep 17 00:00:00 2001 From: Stephen Nneji Date: Thu, 28 Aug 2025 13:37:33 +0100 Subject: [PATCH 1/2] Adds blitting support and updates live plot to use it --- ratapi/utils/plotting.py | 205 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 204 insertions(+), 1 deletion(-) diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index e61e40ae..a6f2f557 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -356,6 +356,205 @@ def plot_ref_sld( plt.show(block=block) +class BlittingSupport: + """Create a SLD plot that uses blitting to get faster draws. + + The blit plot stores the background from an + initial draw then updates the foreground (lines and error bars) if the background is not changed. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + fig : matplotlib.pyplot.figure, optional + The figure class that has two subplots + linear_x : bool, default: False + Controls whether the x-axis on reflectivity plot uses the linear scale + q4 : bool, default: False + Controls whether Q^4 is plotted on the reflectivity plot + show_error_bar : bool, default: True + Controls whether the error bars are shown + show_grid : bool, default: False + Controls whether the grid is shown + show_legend : bool, default: True + Controls whether the legend is shown + shift_value : float, default: 100 + A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts + """ + + def __init__( + self, + data, + fig=None, + linear_x: bool = False, + q4: bool = False, + show_error_bar: bool = True, + show_grid: bool = False, + show_legend: bool = True, + shift_value: float = 100, + ): + self.figure = fig + self.linear_x = linear_x + self.q4 = q4 + self.show_error_bar = show_error_bar + self.show_grid = show_grid + self.show_legend = show_legend + self.shift_value = shift_value + self.update_plot(data) + self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent) + + def __del__(self): + self.figure.canvas.mpl_disconnect(self.event_id) + + def resizeEvent(self, _event): + """Ensure the background is updated after a resize event.""" + self.__background_changed = True + + def update(self, data): + """Update the foreground, if background has not changed otherwise it updates full plot. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + if self.__background_changed: + self.update_plot(data) + else: + self.update_foreground(data) + + def __setattr__(self, name, value): + old_value = getattr(self, name, None) + if value == old_value: + return + + super().__setattr__(name, value) + if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend", "shift_value"]: + self.__background_changed = True + + def set_animated(self, is_animated: bool): + """Set the animated property of foreground plot elements. + + Parameters + ---------- + is_animated : bool + Indicates if the animated property should be set. + """ + for line in self.figure.axes[0].lines: + line.set_animated(is_animated) + for line in self.figure.axes[1].lines: + line.set_animated(is_animated) + for container in self.figure.axes[0].containers: + container[2][0].set_animated(is_animated) + + def adjust_error_bar(self, error_bar_container, x, y, y_error): + """Adjust the error bar data. + + Parameters + ---------- + error_bar_container : Tuple + Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines) + x : np.ndarray + The shifted data x axis data + y : np.ndarray + The shifted data y axis data + y_error : np.ndarray + The shifted data y axis error data + """ + line, _, (bars_y,) = error_bar_container + + line.set_data(x, y) + x_base = x + y_base = y + + y_error_top = y_base + y_error + y_error_bottom = y_base - y_error + + new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)] + bars_y.set_segments(new_segments_y) + + def update_plot(self, data): + """Update the full plot. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + if self.figure is not None: + self.figure.clf() + self.figure = ratapi.plotting.plot_ref_sld_helper( + data, + self.figure, + linear_x=self.linear_x, + q4=self.q4, + show_error_bar=self.show_error_bar, + show_grid=self.show_grid, + show_legend=self.show_legend, + animated=True, + ) + self.figure.tight_layout(pad=1) + self.figure.canvas.draw() + self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox) + for line in self.figure.axes[0].lines: + self.figure.axes[0].draw_artist(line) + for line in self.figure.axes[1].lines: + self.figure.axes[1].draw_artist(line) + for container in self.figure.axes[0].containers: + self.figure.axes[0].draw_artist(container[2][0]) + self.figure.canvas.blit(self.figure.bbox) + self.set_animated(False) + self.__background_changed = False + + def update_foreground(self, data): + """Update the plot foreground only. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + self.set_animated(True) + self.figure.canvas.restore_region(self.bg) + plot_data = ratapi.plotting._extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value) + + offset = 2 if self.show_error_bar else 1 + for i in range( + 0, + len(self.figure.axes[0].lines), + ): + self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1]) + self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i]) + + i = 0 + for j in range(len(plot_data["sld"])): + for sld in plot_data["sld"][j]: + self.figure.axes[1].lines[i].set_data(sld[0], sld[1]) + self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) + i += 1 + + if plot_data["sld_resample"]: + for resampled in plot_data["sld_resample"][j]: + self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1]) + self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) + i += 1 + + for i, container in enumerate(self.figure.axes[0].containers): + self.adjust_error_bar( + container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2] + ) + self.figure.axes[0].draw_artist(container[2][0]) + self.figure.axes[0].draw_artist(container[0]) + + self.figure.canvas.blit(self.figure.bbox) + self.figure.canvas.flush_events() + self.set_animated(False) + + class LivePlot: """Create a plot that gets updates from the plot event during a calculation. @@ -369,6 +568,7 @@ class LivePlot: def __init__(self, block=False): self.block = block self.closed = False + self.blit_plot = None def __enter__(self): self.figure = plt.subplots(1, 2)[0] @@ -394,7 +594,10 @@ def plotEvent(self, event): """ if not self.closed and self.figure.number in plt.get_fignums(): - plot_ref_sld_helper(event, self.figure) + if self.blit_plot is None: + self.blit_plot = BlittingSupport(event, self.figure) + else: + self.blit_plot.update(event) def __exit__(self, _exc_type, _exc_val, _traceback): ratapi.events.clear(ratapi.events.EventTypes.Plot, self.plotEvent) From 03b97095b00f27a3ccc41b314ac7b94db783d015 Mon Sep 17 00:00:00 2001 From: Stephen Nneji Date: Mon, 1 Sep 2025 10:21:54 +0100 Subject: [PATCH 2/2] Skips orsopy tests --- tests/test_orso_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_orso_utils.py b/tests/test_orso_utils.py index 39137fe9..a8c0791a 100644 --- a/tests/test_orso_utils.py +++ b/tests/test_orso_utils.py @@ -36,6 +36,7 @@ def prist(): ], ) @pytest.mark.parametrize("absorption", [True, False]) +@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available") def test_orso_model_to_rat(model, absorption): """Test that orso_model_to_rat gives the expected parameters, layers and model.""" @@ -72,6 +73,7 @@ def test_orso_model_to_rat(model, absorption): "prist5_10K_m_025.Rqz.ort", ], ) +@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available") def test_load_ort_data(test_data): """Test that .ort data is loaded correctly.""" # manually get the test data for comparison @@ -104,6 +106,7 @@ def test_load_ort_data(test_data): ["prist5_10K_m_025.Rqz.ort", "prist.json"], ], ) +@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available") def test_load_ort_project(test_data, expected_data): """Test that a project with model data is loaded correctly.""" ort_data = ORSOProject(Path(TEST_DIR_PATH, test_data))