From 46f7173bb8f8f3cb855d953cf9fd6cbb0850612d Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:53:14 +0000 Subject: [PATCH 1/2] feat: Handle LineString and MultiLineString in _shp_to_patches Refactored `_shp_to_patches` to `_shp_to_lines` to better reflect its new functionality. The function now handles `LineString` and `MultiLineString` geometries, in addition to `Polygon` and `MultiPolygon` geometries. Polygons are processed as their exteriors. The function now returns a `LineCollection` instead of a `PatchCollection`. The `plot_da` function was updated to allow passing `borders` as an argument. A test case was added to verify the new functionality. --- mapflow/_classic.py | 49 +++++++++++++++++++++++++++++---------------- tests/test_plot.py | 12 +++++++++++ 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/mapflow/_classic.py b/mapflow/_classic.py index e475965..cba2b0f 100644 --- a/mapflow/_classic.py +++ b/mapflow/_classic.py @@ -9,11 +9,10 @@ import matplotlib.pyplot as plt import numpy as np import xarray as xr -from matplotlib.collections import PatchCollection +from matplotlib.collections import LineCollection from matplotlib.colors import LogNorm, Normalize -from matplotlib.patches import Polygon as PolygonPatch from pyproj import CRS -from shapely.geometry import MultiPolygon +from shapely.geometry import LineString, MultiLineString, Polygon, MultiPolygon from tqdm.auto import tqdm from ._misc import ( @@ -82,18 +81,23 @@ def __init__(self, x, y, crs=4326, borders=None): else: raise TypeError("borders must be a geopandas GeoDataFrame, GeoSeries, or None.") borders_ = borders_.to_crs(self.crs).clip(bbox) - self.borders = self._shp_to_patches(borders_) + self.borders = self._shp_to_lines(borders_) @staticmethod - def _shp_to_patches(gdf): - patches = [] - for poly in gdf.geometry.values: - if isinstance(poly, MultiPolygon): - for polygon in poly.geoms: - patches.append(PolygonPatch(polygon.exterior.coords)) - else: - patches.append(PolygonPatch(poly.exterior.coords)) - return PatchCollection(patches, facecolor="none", linewidth=0.5, edgecolor="k") + def _shp_to_lines(gdf): + lines = [] + for geom in gdf.geometry.values: + if isinstance(geom, Polygon): + lines.append(geom.exterior.coords) + elif isinstance(geom, MultiPolygon): + for poly in geom.geoms: + lines.append(poly.exterior.coords) + elif isinstance(geom, LineString): + lines.append(geom.coords) + elif isinstance(geom, MultiLineString): + for line in geom.geoms: + lines.append(line.coords) + return LineCollection(lines, linewidth=0.5, edgecolor="k") @staticmethod def _log_norm(data, vmin, vmax, qmin, qmax): @@ -262,7 +266,9 @@ def __call__( plt.show() -def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **kwargs): +def plot_da( + da: xr.DataArray, x_name=None, y_name=None, crs=4326, borders=None, diff=False, **kwargs +): """Convenience function for quick plotting of an xarray DataArray using PlotModel. This is a simplified wrapper around the `PlotModel` class that handles: @@ -281,12 +287,14 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, ** y_name (str, optional): Name of the y-coordinate dimension. If None, will attempt to guess from `["y", "lat", "latitude"]`. crs (int | str | CRS, optional): Coordinate Reference System. Can be an EPSG code, a PROJ string, or a pyproj.CRS object. If the DataArray has a 'crs' attribute, that will be used by default. Defaults to 4326 (WGS84). + borders (gpd.GeoDataFrame | gpd.GeoSeries | None): Custom borders to use. + If None, defaults to world borders from a packaged GeoPackage. + diff (bool, optional): Whether to use a divergent colormap. Defaults to False. **kwargs: Additional arguments passed to `PlotModel.__call__`, including: - `figsize` (tuple, optional): Figure size (width, height) in inches. - `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100). - `vmin`/`vmax` (float, optional): Explicit value ranges for color scaling. - `log` (bool, optional): Whether to use a logarithmic color scale. - - `diff` (bool, optional): Whether to use a divergent colormap. - `cmap` (str, optional): Colormap name. - `norm` (matplotlib.colors.Normalize, optional): Custom normalization object. - `shading` (str, optional): Color shading method. @@ -314,9 +322,16 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, ** da = da.sortby(actual_x_name).sortby(actual_y_name) crs_ = process_crs(da, crs) if crs_.is_geographic: - da[actual_x_name] = xr.where(da[actual_x_name] > 180, da[actual_x_name] - 360, da[actual_x_name]) + da[actual_x_name] = xr.where( + da[actual_x_name] > 180, da[actual_x_name] - 360, da[actual_x_name] + ) - p = PlotModel(x=da[actual_x_name].values, y=da[actual_y_name].values, crs=crs_) + p = PlotModel( + x=da[actual_x_name].values, + y=da[actual_y_name].values, + crs=crs_, + borders=borders, + ) data = p._process_data(da.values) p(data, diff=diff, **kwargs) diff --git a/tests/test_plot.py b/tests/test_plot.py index ee49206..b61c546 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,6 +1,10 @@ import matplotlib.pyplot as plt import pytest import xarray as xr +import tempfile +from pathlib import Path +import geopandas as gpd +from shapely.geometry import Polygon, LineString from mapflow import plot_da, plot_da_quiver @@ -33,3 +37,11 @@ def test_plot_da_quiver(air_temperature_gradient_data): def test_plot_da_diff(air_data): plot_da(da=air_data.isel(time=0), diff=True, show=False) plt.close() + + +def test_plot_da_with_linestring_borders(air_data): + polygon = Polygon([(250, 20), (251, 21), (251, 20)]) + line = LineString([(260, 50), (261, 51)]) + gdf = gpd.GeoDataFrame(geometry=[polygon, line], crs="EPSG:4326") + plot_da(da=air_data.isel(time=0), borders=gdf, show=False) + plt.close() From 3e47c07432a71fa13e70fa7886320d2cf1db222e Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 17 Sep 2025 09:55:48 +0000 Subject: [PATCH 2/2] Format code with Ruff --- mapflow/_classic.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mapflow/_classic.py b/mapflow/_classic.py index cba2b0f..6275660 100644 --- a/mapflow/_classic.py +++ b/mapflow/_classic.py @@ -266,9 +266,7 @@ def __call__( plt.show() -def plot_da( - da: xr.DataArray, x_name=None, y_name=None, crs=4326, borders=None, diff=False, **kwargs -): +def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, borders=None, diff=False, **kwargs): """Convenience function for quick plotting of an xarray DataArray using PlotModel. This is a simplified wrapper around the `PlotModel` class that handles: @@ -322,9 +320,7 @@ def plot_da( da = da.sortby(actual_x_name).sortby(actual_y_name) crs_ = process_crs(da, crs) if crs_.is_geographic: - da[actual_x_name] = xr.where( - da[actual_x_name] > 180, da[actual_x_name] - 360, da[actual_x_name] - ) + da[actual_x_name] = xr.where(da[actual_x_name] > 180, da[actual_x_name] - 360, da[actual_x_name]) p = PlotModel( x=da[actual_x_name].values,