diff --git a/mapflow/_classic.py b/mapflow/_classic.py index ad34018..65834aa 100644 --- a/mapflow/_classic.py +++ b/mapflow/_classic.py @@ -397,14 +397,40 @@ def _process_title(title, upsample_ratio): else: raise ValueError("Title must be a string or a list of strings.") + def _calculate_animation_parameters(self, n_frames_raw, fps, upsample_ratio, duration): + if sum(p is not None for p in [fps, upsample_ratio, duration]) > 2: + raise ValueError("Only two of 'fps', 'upsample_ratio', and 'duration' can be provided.") + + if duration is not None: + if fps is not None: + if n_frames_raw > 1: + upsample_ratio = max(1, round((duration * fps - 1) / (n_frames_raw - 1))) + total_frames = (n_frames_raw - 1) * upsample_ratio + 1 + fps = total_frames / duration + else: + upsample_ratio = 1 + fps = 1 / duration + elif upsample_ratio is not None: + total_frames = (n_frames_raw - 1) * upsample_ratio + 1 if n_frames_raw > 1 else 1 + fps = total_frames / duration + else: # duration only + upsample_ratio = 2 + total_frames = (n_frames_raw - 1) * upsample_ratio + 1 if n_frames_raw > 1 else 1 + fps = total_frames / duration + else: # duration is None + fps = fps or 24 + upsample_ratio = upsample_ratio or 2 + return fps, upsample_ratio + def __call__( self, data, path, figsize: tuple = None, title=None, - fps: int = 24, - upsample_ratio: int = 2, + fps: int = None, + upsample_ratio: int = None, + duration: int = None, cmap="jet", qmin=0.01, qmax=99.9, @@ -439,6 +465,8 @@ def __call__( Defaults to 24. upsample_ratio (int, optional): Factor by which to upsample the data along the time axis for smoother animations. Defaults to 2. + duration (int, optional): Duration of the video in seconds. + Only two of 'fps', 'upsample_ratio', and 'duration' can be provided. cmap (str, optional): Colormap to use for the plot. Defaults to "jet". qmin (float, optional): Minimum quantile for color normalization. Defaults to 0.01. @@ -458,6 +486,9 @@ def __call__( """ if diff: cmap = "bwr" + + fps, upsample_ratio = self._calculate_animation_parameters(len(data), fps, upsample_ratio, duration) + norm = self.plot._norm(data, vmin, vmax, qmin, qmax, norm, log, diff) self._animate( data=data, @@ -615,6 +646,7 @@ def _create_video(tempdir, path, fps, timeout, crf=20): def animate( da: xr.DataArray, path: str, + *, time_name: str = None, x_name: str = None, y_name: str = None, @@ -622,6 +654,9 @@ def animate( borders: gpd.GeoDataFrame | gpd.GeoSeries | None = None, verbose: int = 0, diff=False, + fps: int = None, + upsample_ratio: int = None, + duration: int = None, **kwargs, ): """Creates an animation from an xarray DataArray. @@ -647,6 +682,9 @@ def animate( world borders. Defaults to None. verbose (int, optional): Verbosity level for the Animation class. Defaults to 0. + fps (int, optional): Frames per second for the output video. Defaults to 24. + upsample_ratio (int, optional): Factor to upsample data temporally. Defaults to 2. + duration (int, optional): Duration of the video in seconds. **kwargs: Additional keyword arguments passed to the `Animation` class, including: - `cmap` (str, optional): Colormap for the plot. - `norm` (matplotlib.colors.Normalize, optional): Custom normalization object. @@ -657,8 +695,6 @@ def animate( - `vmin` (float, optional): Minimum value for color normalization. - `vmax` (float, optional): Maximum value for color normalization. - `time_format` (str, optional): Strftime format for time in titles. - - `upsample_ratio` (int, optional): Factor to upsample data temporally. - - `fps` (int, optional): Frames per second for the video. - `n_jobs` (int, optional): Number of parallel jobs for frame generation. - `dpi` (int, optional): Dots per inch for the saved frames. - `timeout` (str | int, optional): Timeout for video creation. @@ -702,5 +738,8 @@ def animate( title=titles, label=unit, diff=diff, + fps=fps, + upsample_ratio=upsample_ratio, + duration=duration, **kwargs, ) diff --git a/mapflow/_misc.py b/mapflow/_misc.py index 8502433..90ca8fc 100644 --- a/mapflow/_misc.py +++ b/mapflow/_misc.py @@ -111,10 +111,10 @@ def check_da(da, time_name, x_name, y_name, crs): da = da.sortby(time_name).squeeze() - if da.ndim != 3: - raise ValueError( - f"DataArray must have 3 dimensions ({time_name}, {y_name}, {x_name}), got {da.ndim} dimensions." - ) + if da.ndim == 2: + da = da.expand_dims(time_name) + elif da.ndim != 3: + raise ValueError(f"DataArray must have 2 or 3 dimensions, but got {da.ndim} dimensions.") # Ensure time is the first dimension if da[x_name].ndim == 1 and da[y_name].ndim == 1: diff --git a/tests/test_animate.py b/tests/test_animate.py index afc2b51..613c8cc 100644 --- a/tests/test_animate.py +++ b/tests/test_animate.py @@ -1,4 +1,5 @@ import os +import subprocess from tempfile import TemporaryDirectory import geopandas as gpd @@ -76,6 +77,103 @@ def test_animate_2d(air_data_2d_coordinates): assert os.path.exists(path) +def get_video_duration(path): + cmd = [ + "ffprobe", + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=noprint_wrappers=1:nokey=1", + path, + ] + result = subprocess.run(cmd, capture_output=True, text=True) + return float(result.stdout) + + +def test_animate_duration_fps(air_data): + duration = 1 + with TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/test_animation_duration_fps.mp4" + animate( + da=air_data, + path=path, + x_name="lon", + y_name="lat", + duration=duration, + fps=30, + verbose=True, + ) + assert os.path.exists(path) + assert abs(get_video_duration(path) - duration) < 0.1 + + +def test_animate_duration_upsample_ratio(air_data): + duration = 1 + with TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/test_animation_duration_upsample.mp4" + animate( + da=air_data, + path=path, + x_name="lon", + y_name="lat", + duration=duration, + upsample_ratio=5, + verbose=True, + ) + assert os.path.exists(path) + assert abs(get_video_duration(path) - duration) < 0.1 + + +def test_animate_duration_only(air_data): + duration = 1 + with TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/test_animation_duration_only.mp4" + animate( + da=air_data, + path=path, + x_name="lon", + y_name="lat", + duration=duration, + verbose=True, + ) + assert os.path.exists(path) + assert abs(get_video_duration(path) - duration) < 0.1 + + +def test_animate_single_frame_duration(air_data): + duration = 2 + with TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/test_animation_single_frame_duration.mp4" + animate( + da=air_data.isel(time=slice(0, 1)), + path=path, + x_name="lon", + y_name="lat", + duration=duration, + fps=10, # This fps will be overridden + verbose=True, + ) + assert os.path.exists(path) + assert abs(get_video_duration(path) - duration) < 0.1 + + +def test_animate_conflicting_args(air_data): + with TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError): + animate( + da=air_data, + path=f"{tmpdir}/test.mp4", + x_name="lon", + y_name="lat", + fps=24, + upsample_ratio=2, + duration=5, + verbose=True, + ) + + def test_animate_quiver(air_temperature_gradient_data): with TemporaryDirectory() as tmpdir: path = f"{tmpdir}/test_animation_quiver.mp4"