Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mapflow/_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def __call__(
plt.show()


def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None, diff=False, **kwargs):
def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None, diff=False, subsample=None, **kwargs):
"""Convenience function for quick plotting of an xarray DataArray using PlotModel.

This is a simplified wrapper around the `PlotModel` class that handles:
Expand All @@ -288,6 +288,8 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None,
borders (gpd.GeoDataFrame | gpd.GeoSeries | None): Custom borders to use.
If None, defaults to world borders from a packaged GeoPackage.
diff (bool, optional): Whether to use a divergent colormap. Defaults to False.
subsample (int, optional): If provided, subsamples the data by this factor for plotting.
Useful for large datasets to speed up plotting. Defaults to None.
**kwargs: Additional arguments passed to `PlotModel.__call__`, including:
- `figsize` (tuple, optional): Figure size (width, height) in inches.
- `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100).
Expand Down Expand Up @@ -316,6 +318,9 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None,
actual_x_name = guess_coord_name(da.coords, X_NAME_CANDIDATES, x_name, "x")
actual_y_name = guess_coord_name(da.coords, Y_NAME_CANDIDATES, y_name, "y")

if subsample is not None:
da = da.isel({actual_x_name: slice(None, None, subsample), actual_y_name: slice(None, None, subsample)})

if da[actual_x_name].ndim == 1 and da[actual_y_name].ndim == 1:
da = da.sortby(actual_x_name).sortby(actual_y_name)
crs_ = process_crs(da, crs)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ def test_plot_da_with_linestring_borders(air_data):
gdf = gpd.GeoDataFrame(geometry=[polygon, line], crs="EPSG:4326")
plot_da(da=air_data.isel(time=0), borders=gdf, show=False)
plt.close()


def test_plot_da_subsample(air_data):
plot_da(da=air_data.isel(time=0), subsample=2, show=False)
plt.close()