diff --git a/mapflow/_classic.py b/mapflow/_classic.py index fbb5c23..e475965 100644 --- a/mapflow/_classic.py +++ b/mapflow/_classic.py @@ -110,7 +110,7 @@ def _log_norm(data, vmin, vmax, qmin, qmax): return LogNorm(vmin=vmin, vmax=vmax) @staticmethod - def _norm(data, vmin, vmax, qmin, qmax, norm, log): + def _norm(data, vmin, vmax, qmin, qmax, norm, log, diff=False): """Generates a normalization based on the specified parameters. Args: @@ -121,6 +121,7 @@ def _norm(data, vmin, vmax, qmin, qmax, norm, log): qmax (float): Maximum quantile for normalization (0-100). norm (matplotlib.colors.Normalize): Custom normalization object. log (bool): Indicates if a logarithmic scale should be used. + diff (bool): Indicates if a divergent colormap should be used. Returns: matplotlib.colors.Normalize: Normalization object. @@ -139,6 +140,12 @@ def _norm(data, vmin, vmax, qmin, qmax, norm, log): if norm is not None: return norm + if diff: + if vmax is None: + vmax = np.nanpercentile(np.abs(data), q=qmax) + vmin = -vmax + return Normalize(vmin=vmin, vmax=vmax) + if log: return PlotModel._log_norm(data, vmin, vmax, qmin, qmax) @@ -170,6 +177,7 @@ def __call__( vmin=None, vmax=None, log=False, + diff=False, cmap="jet", norm=None, shading="nearest", @@ -199,6 +207,8 @@ def __call__( Overrides qmax. Defaults to None. log (bool, optional): Whether to use a logarithmic color scale. Defaults to False. + diff (bool, optional): Whether to use a divergent colormap. + Defaults to False. cmap (str, optional): Colormap to use. Defaults to "jet". norm (matplotlib.colors.Normalize, optional): Custom normalization object. Overrides vmin, vmax, qmin, qmax, log. Defaults to None. @@ -211,8 +221,10 @@ def __call__( show (bool, optional): Whether to display the plot using `plt.show()`. Defaults to True. """ + if diff: + cmap = "bwr" data = self._process_data(data) - norm = self._norm(data, vmin, vmax, qmin, qmax, norm, log=log) + norm = self._norm(data, vmin, vmax, qmin, qmax, norm, log=log, diff=diff) plt.figure(figsize=figsize) if (self.x.ndim == 1) and (self.y.ndim == 1): plt.imshow( @@ -250,7 +262,7 @@ def __call__( plt.show() -def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, **kwargs): +def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **kwargs): """Convenience function for quick plotting of an xarray DataArray using PlotModel. This is a simplified wrapper around the `PlotModel` class that handles: @@ -274,6 +286,7 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, **kwargs): - `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100). - `vmin`/`vmax` (float, optional): Explicit value ranges for color scaling. - `log` (bool, optional): Whether to use a logarithmic color scale. + - `diff` (bool, optional): Whether to use a divergent colormap. - `cmap` (str, optional): Colormap name. - `norm` (matplotlib.colors.Normalize, optional): Custom normalization object. - `shading` (str, optional): Color shading method. @@ -305,7 +318,7 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, **kwargs): p = PlotModel(x=da[actual_x_name].values, y=da[actual_y_name].values, crs=crs_) data = p._process_data(da.values) - p(data, **kwargs) + p(data, diff=diff, **kwargs) class Animation: @@ -381,6 +394,7 @@ def __call__( vmax=None, norm=None, log=False, + diff=False, label=None, dpi=180, n_jobs=None, @@ -415,6 +429,7 @@ def __call__( vmax (float, optional): Maximum value for color normalization. Overrides qmax. norm (matplotlib.colors.Normalize, optional): Custom normalization object. log (bool, optional): Whether to use a logarithmic color scale. Defaults to False. + diff (bool, optional): Whether to use a divergent colormap. Defaults to False. label (str, optional): Label for the colorbar. Defaults to None. dpi (int, optional): Dots per inch for the saved frames. Defaults to 180. n_jobs (int, optional): Number of parallel jobs for frame generation. @@ -422,7 +437,9 @@ def __call__( timeout (int | str, optional): Timeout for the ffmpeg command in seconds. Defaults to "auto", which sets the timeout to `max(20, 0.1 * data_len)`. """ - norm = self.plot._norm(data, vmin, vmax, qmin, qmax, norm, log) + if diff: + cmap = "bwr" + norm = self.plot._norm(data, vmin, vmax, qmin, qmax, norm, log, diff) self._animate( data=data, path=path, @@ -437,6 +454,7 @@ def __call__( dpi=dpi, n_jobs=n_jobs, timeout=timeout, + diff=diff, ) def _animate( @@ -454,6 +472,7 @@ def _animate( dpi=180, n_jobs=None, timeout="auto", + diff=False, ): titles = self._process_title(title, upsample_ratio) data = self.upsample(data, ratio=upsample_ratio) @@ -473,7 +492,7 @@ def _animate( norm, label, dpi, - {}, # No kwargs + {"diff": diff}, ) args.append(arg_tuple) @@ -494,7 +513,7 @@ def _animate( def _generate_frame(self, args): """Generates a frame and saves it as a PNG.""" - data_frame, frame_path, figsize, title, cmap, norm, label, dpi, _ = args + data_frame, frame_path, figsize, title, cmap, norm, label, dpi, kwargs = args self.plot( data=data_frame, figsize=figsize, @@ -503,6 +522,7 @@ def _generate_frame(self, args): cmap=cmap, norm=norm, label=label, + **kwargs, ) plt.savefig(frame_path, dpi=dpi, bbox_inches="tight", pad_inches=0.05) plt.clf() @@ -566,6 +586,7 @@ def animate( crs=None, borders: gpd.GeoDataFrame | gpd.GeoSeries | None = None, verbose: int = 0, + diff=False, **kwargs, ): """Creates an animation from an xarray DataArray. @@ -595,6 +616,7 @@ def animate( - `cmap` (str, optional): Colormap for the plot. - `norm` (matplotlib.colors.Normalize, optional): Custom normalization object. - `log` (bool, optional): Use logarithmic color scale. + - `diff` (bool, optional): Whether to use a divergent colormap. - `qmin` (float, optional): Minimum quantile for color normalization. - `qmax` (float, optional): Maximum quantile for color normalization. - `vmin` (float, optional): Minimum value for color normalization. @@ -643,5 +665,6 @@ def animate( path=output_path, title=titles, label=unit, + diff=diff, **kwargs, ) diff --git a/tests/test_plot.py b/tests/test_plot.py index 18adc2a..ee49206 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -28,3 +28,8 @@ def test_plot_da_quiver(air_temperature_gradient_data): plt.close() plot_da_quiver(u, v, subsample=2, show=False) plt.close() + + +def test_plot_da_diff(air_data): + plot_da(da=air_data.isel(time=0), diff=True, show=False) + plt.close()