def sliding_window(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]:
  """Returns a lazy sliding window (of width `n`) over an iterable.

  ```python
  list(epy.sliding_window([1, 2, 3, 4, 5, 6, 7, 8, 9], 3)) == [
      (1, 2, 3),
      (2, 3, 4),
      (3, 4, 5),
      ...,
      (7, 8, 9),
  ]
  ```

  Nothing is read from `iterable` until the returned iterator is advanced, so
  calling this function on a generator (or any other single-pass input) does
  not consume any of its items up front.

  Args:
    iterable: The iterable to create the sliding window over.
    n: The width of the sliding window. When `iterable` holds fewer than `n`
      items, or when `n == 0`, no window is yielded.

  Returns:
    An iterator yielding every run of `n` consecutive items as a tuple, in
    order.

  Raises:
    ValueError: If `n` is negative.
    TypeError: If `iterable` is not iterable.
  """
  # Validate eagerly so that bad arguments fail at the call site rather than
  # on the first `next()` of the returned iterator.
  if n < 0:
    raise ValueError(f'`n` must be >= 0. Got: {n}')
  # `iter()` rejects non-iterables right away but does not pull any item.
  source = iter(iterable)

  def _windows():
    # Create n independent iterators sharing the same underlying items.
    iters = itertools.tee(source, n)

    # Stagger the iterators: the i-th one skips its first i items, so zipping
    # them together yields (x[k], x[k + 1], ..., x[k + n - 1]).
    for i, it in enumerate(iters):
      # `islice(it, i, i)` is always empty, but advances `it` by `i` items.
      # The `None` default absorbs the resulting `StopIteration`.
      next(itertools.islice(it, i, i), None)

    # `zip` stops as soon as the most advanced iterator is exhausted, which
    # is exactly after the last full window.
    yield from zip(*iters)

  return _windows()