Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions mapflow/_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
return LogNorm(vmin=vmin, vmax=vmax)

@staticmethod
def _norm(data, vmin, vmax, qmin, qmax, norm, log):
def _norm(data, vmin, vmax, qmin, qmax, norm, log, diff=False):

Check warning on line 113 in mapflow/_classic.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

mapflow/_classic.py#L113

Method _norm has a cyclomatic complexity of 10 (limit is 8)
"""Generates a normalization based on the specified parameters.

Args:
Expand All @@ -121,6 +121,7 @@
qmax (float): Maximum quantile for normalization (0-100).
norm (matplotlib.colors.Normalize): Custom normalization object.
log (bool): Indicates if a logarithmic scale should be used.
diff (bool): Indicates if a divergent colormap should be used.

Returns:
matplotlib.colors.Normalize: Normalization object.
Expand All @@ -139,6 +140,12 @@
if norm is not None:
return norm

if diff:
if vmax is None:
vmax = np.nanpercentile(np.abs(data), q=qmax)
vmin = -vmax
return Normalize(vmin=vmin, vmax=vmax)

if log:
return PlotModel._log_norm(data, vmin, vmax, qmin, qmax)

Expand Down Expand Up @@ -170,6 +177,7 @@
vmin=None,
vmax=None,
log=False,
diff=False,
cmap="jet",
norm=None,
shading="nearest",
Expand Down Expand Up @@ -199,6 +207,8 @@
Overrides qmax. Defaults to None.
log (bool, optional): Whether to use a logarithmic color scale.
Defaults to False.
diff (bool, optional): Whether to use a divergent colormap.
Defaults to False.
cmap (str, optional): Colormap to use. Defaults to "jet".
norm (matplotlib.colors.Normalize, optional): Custom normalization object.
Overrides vmin, vmax, qmin, qmax, log. Defaults to None.
Expand All @@ -211,8 +221,10 @@
show (bool, optional): Whether to display the plot using `plt.show()`.
Defaults to True.
"""
if diff:
cmap = "bwr"
data = self._process_data(data)
norm = self._norm(data, vmin, vmax, qmin, qmax, norm, log=log)
norm = self._norm(data, vmin, vmax, qmin, qmax, norm, log=log, diff=diff)
plt.figure(figsize=figsize)
if (self.x.ndim == 1) and (self.y.ndim == 1):
plt.imshow(
Expand Down Expand Up @@ -250,7 +262,7 @@
plt.show()


def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, **kwargs):
def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **kwargs):
"""Convenience function for quick plotting of an xarray DataArray using PlotModel.

This is a simplified wrapper around the `PlotModel` class that handles:
Expand All @@ -274,6 +286,7 @@
- `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100).
- `vmin`/`vmax` (float, optional): Explicit value ranges for color scaling.
- `log` (bool, optional): Whether to use a logarithmic color scale.
- `diff` (bool, optional): Whether to use a divergent colormap.
- `cmap` (str, optional): Colormap name.
- `norm` (matplotlib.colors.Normalize, optional): Custom normalization object.
- `shading` (str, optional): Color shading method.
Expand Down Expand Up @@ -305,7 +318,7 @@

p = PlotModel(x=da[actual_x_name].values, y=da[actual_y_name].values, crs=crs_)
data = p._process_data(da.values)
p(data, **kwargs)
p(data, diff=diff, **kwargs)


class Animation:
Expand Down Expand Up @@ -381,6 +394,7 @@
vmax=None,
norm=None,
log=False,
diff=False,
label=None,
dpi=180,
n_jobs=None,
Expand Down Expand Up @@ -415,14 +429,17 @@
vmax (float, optional): Maximum value for color normalization. Overrides qmax.
norm (matplotlib.colors.Normalize, optional): Custom normalization object.
log (bool, optional): Whether to use a logarithmic color scale. Defaults to False.
diff (bool, optional): Whether to use a divergent colormap. Defaults to False.
label (str, optional): Label for the colorbar. Defaults to None.
dpi (int, optional): Dots per inch for the saved frames. Defaults to 180.
n_jobs (int, optional): Number of parallel jobs for frame generation.
Defaults to 2/3 of CPU cores.
timeout (int | str, optional): Timeout for the ffmpeg command in seconds.
Defaults to "auto", which sets the timeout to `max(20, 0.1 * data_len)`.
"""
norm = self.plot._norm(data, vmin, vmax, qmin, qmax, norm, log)
if diff:
cmap = "bwr"
norm = self.plot._norm(data, vmin, vmax, qmin, qmax, norm, log, diff)
self._animate(
data=data,
path=path,
Expand All @@ -437,6 +454,7 @@
dpi=dpi,
n_jobs=n_jobs,
timeout=timeout,
diff=diff,
)

def _animate(
Expand All @@ -454,6 +472,7 @@
dpi=180,
n_jobs=None,
timeout="auto",
diff=False,
):
titles = self._process_title(title, upsample_ratio)
data = self.upsample(data, ratio=upsample_ratio)
Expand All @@ -473,7 +492,7 @@
norm,
label,
dpi,
{}, # No kwargs
{"diff": diff},
)
args.append(arg_tuple)

Expand All @@ -494,7 +513,7 @@

def _generate_frame(self, args):
"""Generates a frame and saves it as a PNG."""
data_frame, frame_path, figsize, title, cmap, norm, label, dpi, _ = args
data_frame, frame_path, figsize, title, cmap, norm, label, dpi, kwargs = args
self.plot(
data=data_frame,
figsize=figsize,
Expand All @@ -503,6 +522,7 @@
cmap=cmap,
norm=norm,
label=label,
**kwargs,
)
plt.savefig(frame_path, dpi=dpi, bbox_inches="tight", pad_inches=0.05)
plt.clf()
Expand Down Expand Up @@ -566,6 +586,7 @@
crs=None,
borders: gpd.GeoDataFrame | gpd.GeoSeries | None = None,
verbose: int = 0,
diff=False,
**kwargs,
):
"""Creates an animation from an xarray DataArray.
Expand Down Expand Up @@ -595,6 +616,7 @@
- `cmap` (str, optional): Colormap for the plot.
- `norm` (matplotlib.colors.Normalize, optional): Custom normalization object.
- `log` (bool, optional): Use logarithmic color scale.
- `diff` (bool, optional): Whether to use a divergent colormap.
- `qmin` (float, optional): Minimum quantile for color normalization.
- `qmax` (float, optional): Maximum quantile for color normalization.
- `vmin` (float, optional): Minimum value for color normalization.
Expand Down Expand Up @@ -643,5 +665,6 @@
path=output_path,
title=titles,
label=unit,
diff=diff,
**kwargs,
)
5 changes: 5 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ def test_plot_da_quiver(air_temperature_gradient_data):
plt.close()
plot_da_quiver(u, v, subsample=2, show=False)
plt.close()


def test_plot_da_diff(air_data):
plot_da(da=air_data.isel(time=0), diff=True, show=False)
plt.close()
Loading