Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions mapflow/_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.collections import PatchCollection
from matplotlib.collections import LineCollection
from matplotlib.colors import LogNorm, Normalize
from matplotlib.patches import Polygon as PolygonPatch
from pyproj import CRS
from shapely.geometry import MultiPolygon
from shapely.geometry import LineString, MultiLineString, Polygon, MultiPolygon
from tqdm.auto import tqdm

from ._misc import (
Expand Down Expand Up @@ -82,18 +81,23 @@ def __init__(self, x, y, crs=4326, borders=None):
else:
raise TypeError("borders must be a geopandas GeoDataFrame, GeoSeries, or None.")
borders_ = borders_.to_crs(self.crs).clip(bbox)
self.borders = self._shp_to_patches(borders_)
self.borders = self._shp_to_lines(borders_)

@staticmethod
def _shp_to_patches(gdf):
patches = []
for poly in gdf.geometry.values:
if isinstance(poly, MultiPolygon):
for polygon in poly.geoms:
patches.append(PolygonPatch(polygon.exterior.coords))
else:
patches.append(PolygonPatch(poly.exterior.coords))
return PatchCollection(patches, facecolor="none", linewidth=0.5, edgecolor="k")
def _shp_to_lines(gdf):
lines = []
for geom in gdf.geometry.values:
if isinstance(geom, Polygon):
lines.append(geom.exterior.coords)
elif isinstance(geom, MultiPolygon):
for poly in geom.geoms:
lines.append(poly.exterior.coords)
elif isinstance(geom, LineString):
lines.append(geom.coords)
elif isinstance(geom, MultiLineString):
for line in geom.geoms:
lines.append(line.coords)
return LineCollection(lines, linewidth=0.5, edgecolor="k")

@staticmethod
def _log_norm(data, vmin, vmax, qmin, qmax):
Expand Down Expand Up @@ -262,7 +266,7 @@ def __call__(
plt.show()


def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **kwargs):
def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, borders=None, diff=False, **kwargs):
"""Convenience function for quick plotting of an xarray DataArray using PlotModel.

This is a simplified wrapper around the `PlotModel` class that handles:
Expand All @@ -281,12 +285,14 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **
y_name (str, optional): Name of the y-coordinate dimension. If None, will attempt to guess from `["y", "lat", "latitude"]`.
crs (int | str | CRS, optional): Coordinate Reference System. Can be an EPSG code, a PROJ string, or a pyproj.CRS object.
If the DataArray has a 'crs' attribute, that will be used by default. Defaults to 4326 (WGS84).
borders (gpd.GeoDataFrame | gpd.GeoSeries | None): Custom borders to use.
If None, defaults to world borders from a packaged GeoPackage.
diff (bool, optional): Whether to use a divergent colormap. Defaults to False.
**kwargs: Additional arguments passed to `PlotModel.__call__`, including:
- `figsize` (tuple, optional): Figure size (width, height) in inches.
- `qmin`/`qmax` (float, optional): Quantile ranges for color scaling (0-100).
- `vmin`/`vmax` (float, optional): Explicit value ranges for color scaling.
- `log` (bool, optional): Whether to use a logarithmic color scale.
- `diff` (bool, optional): Whether to use a divergent colormap.
- `cmap` (str, optional): Colormap name.
- `norm` (matplotlib.colors.Normalize, optional): Custom normalization object.
- `shading` (str, optional): Color shading method.
Expand Down Expand Up @@ -316,7 +322,12 @@ def plot_da(da: xr.DataArray, x_name=None, y_name=None, crs=4326, diff=False, **
if crs_.is_geographic:
da[actual_x_name] = xr.where(da[actual_x_name] > 180, da[actual_x_name] - 360, da[actual_x_name])

p = PlotModel(x=da[actual_x_name].values, y=da[actual_y_name].values, crs=crs_)
p = PlotModel(
x=da[actual_x_name].values,
y=da[actual_y_name].values,
crs=crs_,
borders=borders,
)
data = p._process_data(da.values)
p(data, diff=diff, **kwargs)

Expand Down
12 changes: 12 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import matplotlib.pyplot as plt
import pytest
import xarray as xr
import tempfile
from pathlib import Path
import geopandas as gpd
from shapely.geometry import Polygon, LineString

from mapflow import plot_da, plot_da_quiver

Expand Down Expand Up @@ -33,3 +37,11 @@ def test_plot_da_quiver(air_temperature_gradient_data):
def test_plot_da_diff(air_data):
plot_da(da=air_data.isel(time=0), diff=True, show=False)
plt.close()


def test_plot_da_with_linestring_borders(air_data):
polygon = Polygon([(250, 20), (251, 21), (251, 20)])
line = LineString([(260, 50), (261, 51)])
gdf = gpd.GeoDataFrame(geometry=[polygon, line], crs="EPSG:4326")
plot_da(da=air_data.isel(time=0), borders=gdf, show=False)
plt.close()