Skip to content

Sliding window operation in IterDataset / build grain locally #1264

@MathiesW

Description

@MathiesW

I need a sliding window transform for my research which, similar to the batch function, collects consecutive records and releases them as a window. It differs from the batch function in that

  1. windows are overlapping
  2. I need to have a stride >= 1 between elements in each window.

My attempt was to subclass IterDataset and simply add a window function. However, due to the hierarchical structure of grain this approach is not sustainable, as I would have to re-implement each class that inherits from IterDataset. Therefore, I cloned the repository and added the window function to the parent class IterDataset:

class IterDataset(_Dataset, Iterable[T], metaclass=IterDatasetMeta):
  def window(
      self,
      window_size: int,
      shift: int = 1,
      stride: int = 1,
      *,
      window_fn: Callable[[Sequence[T]], S] | None = None,
  ) -> IterDataset[S]:
    """Returns a sliding window view of consecutive elements along a new first dimension.

    Dataset elements are expected to be PyTrees.

    Example usage::

      ds = MapDataset.range(5).to_iter_dataset()
      ds = ds.window(window_size=3)
      # Intended result: three overlapping windows holding the values
      # [0, 1, 2], [1, 2, 3] and [2, 3, 4].
      windows = list(ds)

    Args:
      window_size: Number of elements in each window.
      shift: Number of input elements by which the window start advances
        between two consecutive windows.
      stride: Distance, in input elements, between neighbouring elements
        inside one window.
      window_fn: Callable that combines the sequence of window elements into
        the output element. When ``None``, ``WindowIterDataset`` picks its
        own default.

    Returns:
      A ``WindowIterDataset`` wrapping this dataset.
    """
    # Imported inside the method, presumably to avoid a circular import with
    # the transformations package -- TODO confirm.
    from grain._src.python.dataset.transformations import window
    return window.WindowIterDataset(
        parent=self,
        window_size=window_size,
        shift=shift,
        stride=stride,
        window_fn=window_fn,
    )

The necessary window.WindowIterDataset transform code is attached at the bottom.
Now my problem is that I have to rebuild the project locally. I tried to do so via grain/oss/common_runner.sh; it ran for a long time but did not output a Python wheel.
Please help me build grain locally so that I can test my window function.

"""Implements window transformations."""

from __future__ import annotations

from collections import deque
from typing import Any, Callable, Sequence, TypeVar

from grain._src.core import tree_lib
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats

from .batch import make_batch as make_window

T = TypeVar("T")
S = TypeVar("S")


def _get_window_element_spec(
    input_spec: Any,
    window_size: int
):
    """Derives the element spec of a windowed dataset from its input spec.

    Every leaf of ``input_spec`` is replaced by a ``base.ShapeDtypeStruct``
    with the same dtype and with ``window_size`` prepended to its shape.

    Args:
        input_spec: Structure whose leaves expose ``shape`` and ``dtype``.
        window_size: Size of the new leading dimension.

    Returns:
        The result of mapping the leaf conversion over ``input_spec`` with
        ``tree_lib.map_structure``.
    """

    def _prepend_window_dim(leaf):
        windowed_shape = (window_size,) + leaf.shape
        return base.ShapeDtypeStruct(shape=windowed_shape, dtype=leaf.dtype)

    return tree_lib.map_structure(_prepend_window_dim, input_spec)


class _WindowDatasetIterator(dataset.DatasetIterator[T]):
    """Iterator that yields sliding windows over a parent iterator.

    A window consists of the buffered elements at offsets
    ``0, stride, 2 * stride, ...`` and the window start advances by ``shift``
    parent elements between two outputs. Elements are kept in a look-ahead
    buffer; for every buffered element the parent state captured right before
    it was fetched is kept as well, so that a checkpoint can rewind the parent
    to the start of the next window instead of skipping the look-ahead.
    """

    def __init__(
            self,
            parent: dataset.DatasetIterator[S],
            window_size: int,
            shift: int,
            stride: int,
            drop_remainder: bool = True,
            window_fn: Callable[[Sequence[S]], T] | None = None,
    ):
        """Initializes the iterator.

        Args:
            parent: Iterator supplying the elements to window over.
            window_size: Number of elements per (full) window.
            shift: Number of parent elements the window start advances per
                output. Expected to be >= 1 (not validated here).
            stride: Offset between neighbouring elements of one window.
                Expected to be >= 1 (not validated here).
            drop_remainder: If true, iteration stops as soon as no full window
                can be built. If false, trailing windows with fewer elements
                are emitted until the buffer is empty.
            window_fn: Combines the list of window elements into the output.
                Falls back to ``make_window`` when ``None``.
        """
        # `super` must be *called*; `super.__init__(parent)` raises TypeError.
        super().__init__(parent)
        self._window_size = window_size
        self._shift = shift
        self._stride = stride
        self._drop_remainder = drop_remainder
        self._window_fn = make_window if window_fn is None else window_fn

        # Look-ahead elements and, in lockstep, the parent state captured
        # immediately before each of them was fetched.
        self._buffer = deque()
        self._buffer_states = deque()
        # Span of parent elements covered by one full window.
        self._num_required_elements = (window_size - 1) * stride + 1
        self._exhausted = False

    def _pull(self) -> bool:
        """Moves one parent element into the buffer.

        Returns:
            True if an element was buffered, False if the parent is exhausted.
        """
        if self._exhausted:
            return False
        # NOTE(review): assumes the parent's `get_state()` returns a snapshot
        # that is not mutated by later `next()` calls -- confirm for parents
        # with in-place state.
        state_before = self._parent.get_state()
        try:
            element = next(self._parent)
        except StopIteration:
            self._exhausted = True
            return False
        self._buffer.append(element)
        self._buffer_states.append(state_before)
        return True

    def _fill_buffer(self):
        """Buffers parent elements until a full window fits or the parent ends."""
        while len(self._buffer) < self._num_required_elements:
            if not self._pull():
                break

    @stats.record_next_duration_if_output
    @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
    def __next__(self) -> T:
        """Returns the next window.

        Raises:
            StopIteration: If no (full, when ``drop_remainder`` is set) window
                can be built from the remaining elements.
        """
        self._fill_buffer()

        available = len(self._buffer)
        if available < self._num_required_elements and (
                self._drop_remainder or not available
        ):
            raise StopIteration

        # For a full buffer this picks exactly `window_size` elements; for a
        # short one it picks whatever strided elements are left.
        span = min(available, self._num_required_elements)
        window = [self._buffer[i] for i in range(0, span, self._stride)]

        # Advance the window start by `shift` parent elements. Pulling after
        # every pop also discards elements when `shift` exceeds the buffer.
        for _ in range(self._shift):
            if self._buffer:
                self._buffer.popleft()
                self._buffer_states.popleft()
            self._pull()

        with self._stats.record_self_time():
            return self._stats.record_output_spec(self._window_fn(window))

    def get_state(self):
        """Returns the parent state from which the next window can be rebuilt."""
        if self._buffer_states:
            # State captured before the oldest buffered element was fetched.
            return self._buffer_states[0]
        return self._parent.get_state()

    def set_state(self, state):
        """Restores a state from `get_state`, discarding all buffered elements."""
        self._buffer.clear()
        self._buffer_states.clear()
        self._exhausted = False
        self._parent.set_state(state)

    def _get_next_index(self) -> int:
        """Returns the index of the window the next `__next__` call produces.

        Window ``k`` starts at parent element ``k * shift``, so the parent
        position is corrected by the buffered look-ahead and divided by
        ``shift`` (rounding up once the parent is exhausted).
        """
        next_window_start = dataset.get_next_index(self._parent) - len(
            self._buffer
        )
        return (next_window_start + self._shift - 1) // self._shift

    def _set_next_index(self, index: int) -> None:
        """Positions the iterator so that window ``index`` is produced next."""
        self._buffer.clear()
        self._buffer_states.clear()
        self._exhausted = False
        dataset.set_next_index(self._parent, index * self._shift)

    def __str__(self) -> str:
        return (
            f"WindowDatasetIterator(window_size={self._window_size},"
            f" shift={self._shift},"
            f" stride={self._stride})"
        )

class WindowIterDataset(dataset.IterDataset[T]):
    """Dataset that groups consecutive parent elements into sliding windows.

    Each output is built from ``window_size`` parent elements that are
    ``stride`` elements apart; the start of consecutive windows is ``shift``
    elements apart.
    """

    def __init__(
        self,
        parent: dataset.IterDataset[S],
        window_size: int,
        shift: int = 1,
        stride: int = 1,
        window_fn: Callable[[Sequence[S]], T] | None = None,
    ):
        """Initializes the dataset.

        Args:
            parent: Dataset whose elements are windowed.
            window_size: Number of elements per window; must be greater than 1.
            shift: Number of parent elements between two window starts; must
                be at least 1.
            stride: Number of parent elements between neighbouring elements of
                a window; must be at least 1.
            window_fn: Combines the sequence of window elements into one
                output element. Defaults to ``make_window``.

        Raises:
            ValueError: If ``window_size``, ``shift`` or ``stride`` is out of
                range.
        """
        # `super` must be *called*; `super.__init__(parent)` raises TypeError.
        super().__init__(parent)
        if window_size <= 1:
            raise ValueError("window size must be positive and greater 1.")
        if shift < 1:
            # shift == 0 would repeat the same window forever.
            raise ValueError(f"shift must be at least 1, got {shift}.")
        if stride < 1:
            # stride <= 0 would index the same or preceding elements.
            raise ValueError(f"stride must be at least 1, got {stride}.")
        self._window_size = window_size
        self._shift = shift
        self._stride = stride
        self._window_fn = make_window if window_fn is None else window_fn

    def __iter__(self) -> _WindowDatasetIterator[T]:
        """Returns a fresh window iterator over a fresh parent iterator."""
        parent_iter = self._parent.__iter__()
        return _WindowDatasetIterator(
            parent_iter,
            self._window_size,
            shift=self._shift,
            stride=self._stride,
            # The iterator's constructor takes `drop_remainder`. Only full
            # windows are requested because `_element_spec` declares a fixed
            # leading dimension of `window_size`.
            drop_remainder=True,
            window_fn=self._window_fn,
        )

    @property
    def _element_spec(self) -> Any:
        """Parent element spec with ``window_size`` prepended to every shape."""
        return _get_window_element_spec(
            dataset.get_element_spec(self._parent),
            self._window_size
        )

    def __str__(self) -> str:
        return (
            f"WindowIterDataset(window_size={self._window_size},"
            f" shift={self._shift},"
            f" stride={self._stride})"
        )

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions