diff --git a/mapflow/_classic.py b/mapflow/_classic.py index fe31745..15e21aa 100644 --- a/mapflow/_classic.py +++ b/mapflow/_classic.py @@ -266,7 +266,7 @@ def __call__( plt.show() -def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None, diff=False, **kwargs): +def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None, diff=False, subsample=None, **kwargs): """Convenience function for quick plotting of an xarray DataArray using PlotModel. This is a simplified wrapper around the `PlotModel` class that handles: @@ -288,6 +288,8 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None, borders (gpd.GeoDataFrame | gpd.GeoSeries | None): Custom borders to use. If None, defaults to world borders from a packaged GeoPackage. diff (bool, optional): Whether to use a divergent colormap. Defaults to False. + subsample (int, optional): If provided, subsamples the data by this factor for plotting. + Useful for large datasets to speed up plotting. Defaults to None. **kwargs: Additional arguments passed to `PlotModel.__call__`, including: - `figsize` (tuple, optional): Figure size (width, height) in inches. - `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100). @@ -316,6 +318,9 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=None, borders=None, actual_x_name = guess_coord_name(da.coords, X_NAME_CANDIDATES, x_name, "x") actual_y_name = guess_coord_name(da.coords, Y_NAME_CANDIDATES, y_name, "y") + if subsample is not None: + da = da.isel({actual_x_name: slice(None, None, subsample), actual_y_name: slice(None, None, subsample)}) + if da[actual_x_name].ndim == 1 and da[actual_y_name].ndim == 1: da = da.sortby(actual_x_name).sortby(actual_y_name) crs_ = process_crs(da, crs) diff --git a/tests/test_plot.py b/tests/test_plot.py index b61c546..d011d0d 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -45,3 +45,8 @@ def test_plot_da_with_linestring_borders(air_data): gdf = gpd.GeoDataFrame(geometry=[polygon, line], crs="EPSG:4326") plot_da(da=air_data.isel(time=0), borders=gdf, show=False) plt.close() + + +def test_plot_da_subsample(air_data): + plot_da(da=air_data.isel(time=0), subsample=2, show=False) + plt.close()