From f30fadb1281f27ef49502b60f6cd12b071c7a095 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:04:23 +0000 Subject: [PATCH 1/3] Add a `diff` argument to the classic plotting functions (`plot_da`, `animate`, `PlotModel`, `Animation`). When `diff=True`, the colormap is set to 'bwr' and the color range is made symmetric around zero. - If `vmax` is provided, `vmin` is set to `-vmax`. - Otherwise, `vmax` is calculated from the `qmax` percentile of the absolute data, and `vmin` is set to `-vmax`. --- mapflow/_classic.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/mapflow/_classic.py b/mapflow/_classic.py index fbb5c23..e475965 100644 --- a/mapflow/_classic.py +++ b/mapflow/_classic.py @@ -110,7 +110,7 @@ def _log_norm(data, vmin, vmax, qmin, qmax): return LogNorm(vmin=vmin, vmax=vmax) @staticmethod - def _norm(data, vmin, vmax, qmin, qmax, norm, log): + def _norm(data, vmin, vmax, qmin, qmax, norm, log, diff=False): """Generates a normalization based on the specified parameters. Args: @@ -121,6 +121,7 @@ def _norm(data, vmin, vmax, qmin, qmax, norm, log): qmax (float): Maximum quantile for normalization (0-100). norm (matplotlib.colors.Normalize): Custom normalization object. log (bool): Indicates if a logarithmic scale should be used. + diff (bool): Indicates if a divergent colormap should be used. Returns: matplotlib.colors.Normalize: Normalization object. @@ -139,6 +140,12 @@ def _norm(data, vmin, vmax, qmin, qmax, norm, log): if norm is not None: return norm + if diff: + if vmax is None: + vmax = np.nanpercentile(np.abs(data), q=qmax) + vmin = -vmax + return Normalize(vmin=vmin, vmax=vmax) + if log: return PlotModel._log_norm(data, vmin, vmax, qmin, qmax) @@ -170,6 +177,7 @@ def __call__( vmin=None, vmax=None, log=False, + diff=False, cmap="jet", norm=None, shading="nearest", @@ -199,6 +207,8 @@ def __call__( Overrides qmax. Defaults to None. log (bool, optional): Whether to use a logarithmic color scale. Defaults to False. + diff (bool, optional): Whether to use a divergent colormap. + Defaults to False. cmap (str, optional): Colormap to use. Defaults to "jet". norm (matplotlib.colors.Normalize, optional): Custom normalization object. Overrides vmin, vmax, qmin, qmax, log. Defaults to None. @@ -211,8 +221,10 @@ def __call__( show (bool, optional): Whether to display the plot using `plt.show()`. Defaults to True. """ + if diff: + cmap = "bwr" data = self._process_data(data) - norm = self._norm(data, vmin, vmax, qmin, qmax, norm, log=log) + norm = self._norm(data, vmin, vmax, qmin, qmax, norm, log=log, diff=diff) plt.figure(figsize=figsize) if (self.x.ndim == 1) and (self.y.ndim == 1): plt.imshow( @@ -250,7 +262,7 @@ def __call__( plt.show() -def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, **kwargs): +def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **kwargs): """Convenience function for quick plotting of an xarray DataArray using PlotModel. This is a simplified wrapper around the `PlotModel` class that handles: @@ -274,6 +286,7 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, **kwargs): - `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100). - `vmin`/`vmax` (float, optional): Explicit value ranges for color scaling. - `log` (bool, optional): Whether to use a logarithmic color scale. + - `diff` (bool, optional): Whether to use a divergent colormap. - `cmap` (str, optional): Colormap name. - `norm` (matplotlib.colors.Normalize, optional): Custom normalization object. - `shading` (str, optional): Color shading method. @@ -305,7 +318,7 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, **kwargs): p = PlotModel(x=da[actual_x_name].values, y=da[actual_y_name].values, crs=crs_) data = p._process_data(da.values) - p(data, **kwargs) + p(data, diff=diff, **kwargs) class Animation: @@ -381,6 +394,7 @@ def __call__( vmax=None, norm=None, log=False, + diff=False, label=None, dpi=180, n_jobs=None, @@ -415,6 +429,7 @@ def __call__( vmax (float, optional): Maximum value for color normalization. Overrides qmax. norm (matplotlib.colors.Normalize, optional): Custom normalization object. log (bool, optional): Whether to use a logarithmic color scale. Defaults to False. + diff (bool, optional): Whether to use a divergent colormap. Defaults to False. label (str, optional): Label for the colorbar. Defaults to None. dpi (int, optional): Dots per inch for the saved frames. Defaults to 180. n_jobs (int, optional): Number of parallel jobs for frame generation. @@ -422,7 +437,9 @@ def __call__( timeout (int | str, optional): Timeout for the ffmpeg command in seconds. Defaults to "auto", which sets the timeout to `max(20, 0.1 * data_len)`. """ - norm = self.plot._norm(data, vmin, vmax, qmin, qmax, norm, log) + if diff: + cmap = "bwr" + norm = self.plot._norm(data, vmin, vmax, qmin, qmax, norm, log, diff) self._animate( data=data, path=path, @@ -437,6 +454,7 @@ def __call__( dpi=dpi, n_jobs=n_jobs, timeout=timeout, + diff=diff, ) def _animate( @@ -454,6 +472,7 @@ def _animate( dpi=180, n_jobs=None, timeout="auto", + diff=False, ): titles = self._process_title(title, upsample_ratio) data = self.upsample(data, ratio=upsample_ratio) @@ -473,7 +492,7 @@ def _animate( norm, label, dpi, - {}, # No kwargs + {"diff": diff}, ) args.append(arg_tuple) @@ -494,7 +513,7 @@ def _animate( def _generate_frame(self, args): """Generates a frame and saves it as a PNG.""" - data_frame, frame_path, figsize, title, cmap, norm, label, dpi, _ = args + data_frame, frame_path, figsize, title, cmap, norm, label, dpi, kwargs = args self.plot( data=data_frame, figsize=figsize, @@ -503,6 +522,7 @@ def _generate_frame(self, args): cmap=cmap, norm=norm, label=label, + **kwargs, ) plt.savefig(frame_path, dpi=dpi, bbox_inches="tight", pad_inches=0.05) plt.clf() @@ -566,6 +586,7 @@ def animate( crs=None, borders: gpd.GeoDataFrame | gpd.GeoSeries | None = None, verbose: int = 0, + diff=False, **kwargs, ): """Creates an animation from an xarray DataArray. @@ -595,6 +616,7 @@ def animate( - `cmap` (str, optional): Colormap for the plot. - `norm` (matplotlib.colors.Normalize, optional): Custom normalization object. - `log` (bool, optional): Use logarithmic color scale. + - `diff` (bool, optional): Whether to use a divergent colormap. - `qmin` (float, optional): Minimum quantile for color normalization. - `qmax` (float, optional): Maximum quantile for color normalization. - `vmin` (float, optional): Minimum value for color normalization. @@ -643,5 +665,6 @@ def animate( path=output_path, title=titles, label=unit, + diff=diff, **kwargs, ) From 3f809a91983b9c646f9847714fd656ac19c060e6 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:29:52 +0000 Subject: [PATCH 2/3] Add a test for the `diff` argument in `plot_da`. This test mocks the underlying plotting function to assert that the `cmap` is set to 'bwr' and that the normalization is symmetric around zero when `diff=True`. --- tests/test_plot.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_plot.py b/tests/test_plot.py index 18adc2a..4fb06cc 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,3 +1,4 @@ +from unittest.mock import patch import matplotlib.pyplot as plt import pytest import xarray as xr @@ -28,3 +29,12 @@ def test_plot_da_quiver(air_temperature_gradient_data): plt.close() plot_da_quiver(u, v, subsample=2, show=False) plt.close() + + +def test_plot_da_diff(air_data): + with patch("matplotlib.pyplot.imshow") as mock_imshow: + plot_da(da=air_data.isel(time=0), diff=True, show=False) + plt.close() + kwargs = mock_imshow.call_args.kwargs + assert kwargs["cmap"] == "bwr" + assert kwargs["norm"].vmin == -kwargs["norm"].vmax From 9fd7b5164e27eca059a4aa4847da9c2d66714c69 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 17:57:30 +0000 Subject: [PATCH 3/3] Simplify the test for the `diff` argument. As requested, this commit removes the mocking and assertions from the test `test_plot_da_diff`. The test now only checks that the function runs without errors when `diff=True`. --- tests/test_plot.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_plot.py b/tests/test_plot.py index 4fb06cc..ee49206 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,4 +1,3 @@ -from unittest.mock import patch import matplotlib.pyplot as plt import pytest import xarray as xr @@ -32,9 +31,5 @@ def test_plot_da_quiver(air_temperature_gradient_data): def test_plot_da_diff(air_data): - with patch("matplotlib.pyplot.imshow") as mock_imshow: - plot_da(da=air_data.isel(time=0), diff=True, show=False) - plt.close() - kwargs = mock_imshow.call_args.kwargs - assert kwargs["cmap"] == "bwr" - assert kwargs["norm"].vmin == -kwargs["norm"].vmax + plot_da(da=air_data.isel(time=0), diff=True, show=False) + plt.close()