diff --git a/mapflow/_classic.py b/mapflow/_classic.py index e475965..6275660 100644 --- a/mapflow/_classic.py +++ b/mapflow/_classic.py @@ -9,11 +9,10 @@ import matplotlib.pyplot as plt import numpy as np import xarray as xr -from matplotlib.collections import PatchCollection +from matplotlib.collections import LineCollection from matplotlib.colors import LogNorm, Normalize -from matplotlib.patches import Polygon as PolygonPatch from pyproj import CRS -from shapely.geometry import MultiPolygon +from shapely.geometry import LineString, MultiLineString, Polygon, MultiPolygon from tqdm.auto import tqdm from ._misc import ( @@ -82,18 +81,23 @@ def __init__(self, x, y, crs=4326, borders=None): else: raise TypeError("borders must be a geopandas GeoDataFrame, GeoSeries, or None.") borders_ = borders_.to_crs(self.crs).clip(bbox) - self.borders = self._shp_to_patches(borders_) + self.borders = self._shp_to_lines(borders_) @staticmethod - def _shp_to_patches(gdf): - patches = [] - for poly in gdf.geometry.values: - if isinstance(poly, MultiPolygon): - for polygon in poly.geoms: - patches.append(PolygonPatch(polygon.exterior.coords)) - else: - patches.append(PolygonPatch(poly.exterior.coords)) - return PatchCollection(patches, facecolor="none", linewidth=0.5, edgecolor="k") + def _shp_to_lines(gdf): + lines = [] + for geom in gdf.geometry.values: + if isinstance(geom, Polygon): + lines.append(geom.exterior.coords) + elif isinstance(geom, MultiPolygon): + for poly in geom.geoms: + lines.append(poly.exterior.coords) + elif isinstance(geom, LineString): + lines.append(geom.coords) + elif isinstance(geom, MultiLineString): + for line in geom.geoms: + lines.append(line.coords) + return LineCollection(lines, linewidth=0.5, edgecolor="k") @staticmethod def _log_norm(data, vmin, vmax, qmin, qmax): @@ -262,7 +266,7 @@ def __call__( plt.show() -def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **kwargs): +def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, borders=None, diff=False, **kwargs): """Convenience function for quick plotting of an xarray DataArray using PlotModel. This is a simplified wrapper around the `PlotModel` class that handles: @@ -281,12 +285,14 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, ** y_name (str, optional): Name of the y-coordinate dimension. If None, will attempt to guess from `["y", "lat", "latitude"]`. crs (int | str | CRS, optional): Coordinate Reference System. Can be an EPSG code, a PROJ string, or a pyproj.CRS object. If the DataArray has a 'crs' attribute, that will be used by default. Defaults to 4326 (WGS84). + borders (gpd.GeoDataFrame | gpd.GeoSeries | None): Custom borders to use. + If None, defaults to world borders from a packaged GeoPackage. + diff (bool, optional): Whether to use a divergent colormap. Defaults to False. **kwargs: Additional arguments passed to `PlotModel.__call__`, including: - `figsize` (tuple, optional): Figure size (width, height) in inches. - `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100). - `vmin`/`vmax` (float, optional): Explicit value ranges for color scaling. - `log` (bool, optional): Whether to use a logarithmic color scale. - - `diff` (bool, optional): Whether to use a divergent colormap. - `cmap` (str, optional): Colormap name. - `norm` (matplotlib.colors.Normalize, optional): Custom normalization object. - `shading` (str, optional): Color shading method. @@ -316,7 +322,12 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, ** if crs_.is_geographic: da[actual_x_name] = xr.where(da[actual_x_name] > 180, da[actual_x_name] - 360, da[actual_x_name]) - p = PlotModel(x=da[actual_x_name].values, y=da[actual_y_name].values, crs=crs_) + p = PlotModel( + x=da[actual_x_name].values, + y=da[actual_y_name].values, + crs=crs_, + borders=borders, + ) data = p._process_data(da.values) p(data, diff=diff, **kwargs) diff --git a/tests/test_plot.py b/tests/test_plot.py index ee49206..b61c546 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,6 +1,10 @@ import matplotlib.pyplot as plt import pytest import xarray as xr +import tempfile +from pathlib import Path +import geopandas as gpd +from shapely.geometry import Polygon, LineString from mapflow import plot_da, plot_da_quiver @@ -33,3 +37,11 @@ def test_plot_da_quiver(air_temperature_gradient_data): def test_plot_da_diff(air_data): plot_da(da=air_data.isel(time=0), diff=True, show=False) plt.close() + + +def test_plot_da_with_linestring_borders(air_data): + polygon = Polygon([(250, 20), (251, 21), (251, 20)]) + line = LineString([(260, 50), (261, 51)]) + gdf = gpd.GeoDataFrame(geometry=[polygon, line], crs="EPSG:4326") + plot_da(da=air_data.isel(time=0), borders=gdf, show=False) + plt.close()