-
Notifications
You must be signed in to change notification settings - Fork 72
Description
I need a sliding window transform for my research, which (similar to the batch function) collects consecutive records and releases them as a window. The difference to the batch function is that
- windows are overlapping
- I need to have a `stride >= 1` between elements in each window.
My attempt was to subclass IterDataset and simply add a window function. However, due to the hierarchical structure of grain this approach is not sustainable, as I would have to re-implement each class that inherits from IterDataset. Therefore, I cloned the repository and added the window function to the parent class IterDataset:
class IterDataset(_Dataset, Iterable[T], metaclass=IterDatasetMeta):

  def window(
      self,
      window_size: int,
      shift: int = 1,
      stride: int = 1,
      *,
      window_fn: Callable[[Sequence[T]], S] | None = None,
  ) -> IterDataset[S]:
    """Returns a dataset of sliding windows over consecutive elements.

    All arguments are forwarded unchanged to `window.WindowIterDataset`, which
    implements the transformation. Dataset elements are expected to be
    PyTrees.

    Example usage::

      ds = MapDataset.range(5).to_iter_dataset()
      ds = ds.window(window_size=3)
      list(ds) == [np.array([0, 1, 2]), np.array([1, 2, 3]), np.array([2, 3, 4])]

    Args:
      window_size: Number of elements collected into each window.
      shift: Number of input elements by which the start of the window
        advances between two consecutive windows.
      stride: Distance, in input elements, between neighbouring elements of
        the same window.
      window_fn: Function that turns the sequence of collected elements into
        the emitted window. When `None`, `WindowIterDataset` substitutes its
        own default.

    Returns:
      An `IterDataset` whose elements are the windows.
    """
    # Imported at call time: the `window` module itself imports this dataset
    # module, so a module-level import here would be circular.
    from grain._src.python.dataset.transformations import window
    return window.WindowIterDataset(
        parent=self,
        window_size=window_size,
        shift=shift,
        stride=stride,
        window_fn=window_fn
    )
The necessary window.WindowIterDataset transform code is attached at the bottom.
Now my problem is that I have to rebuild the project locally. I tried to get it to work via grain/oss/common_runner.sh; it ran for a long time, but it did not output a Python wheel.
Please help me figure out how to build grain locally so that I can test my window function.
"""Implements window transformations."""
from __future__ import annotations
from collections import deque
from typing import Any, Callable, Sequence, TypeVar
from grain._src.core import tree_lib
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
from .batch import make_batch as make_window
T = TypeVar("T")
S = TypeVar("S")
def _get_window_element_spec(
    input_spec: Any,
    window_size: int
):
  """Derives the element spec of a windowed dataset from its input spec.

  Args:
    input_spec: Structure whose leaves expose `shape` and `dtype`.
    window_size: Length of the leading window dimension to prepend.

  Returns:
    The same structure with every leaf replaced by a `base.ShapeDtypeStruct`
    of shape `(window_size,) + leaf.shape` and the leaf's original dtype.
  """

  def _prepend_window_dim(leaf):
    windowed_shape = (window_size,) + leaf.shape
    return base.ShapeDtypeStruct(shape=windowed_shape, dtype=leaf.dtype)

  return tree_lib.map_structure(_prepend_window_dim, input_spec)
class _WindowDatasetIterator(dataset.DatasetIterator[T]):
  """Iterator yielding sliding windows over the parent iterator's elements.

  Window `k` starts at parent element `k * shift` and holds `window_size`
  elements that are `stride` apart, so one full window spans
  `(window_size - 1) * stride + 1` consecutive parent elements. The selected
  elements are handed to `window_fn` as a list and its result is yielded.
  """

  def __init__(
      self,
      parent: dataset.DatasetIterator[S],
      window_size: int,
      shift: int,
      stride: int,
      drop_remainder: bool,
      window_fn: Callable[[Sequence[S]], T]
  ):
    """Initializes the iterator.

    Args:
      parent: Iterator supplying the elements to be windowed.
      window_size: Number of elements per full window.
      shift: Number of parent elements the window start advances per step.
      stride: Distance in parent elements between neighbours within a window.
      drop_remainder: If true, iteration stops as soon as a full window can no
        longer be assembled. If false, the trailing windows are emitted with
        fewer elements until the buffered elements run out.
      window_fn: Combines the list of selected elements into the output.
    """
    super().__init__(parent)
    self._window_size = window_size
    self._shift = shift
    self._stride = stride
    self._drop_remainder = drop_remainder
    self._window_fn = window_fn
    # Parent elements from the current window start onwards.
    self._buffer = deque()
    # Parent state captured immediately before reading the element at the same
    # position in `_buffer`. The front entry is the state from which the
    # current window can be reproduced after a checkpoint restore.
    # NOTE(review): costs one `parent.get_state()` per element and assumes the
    # returned object is not mutated by the parent afterwards - confirm.
    self._buffer_states = deque()
    self._num_required_elements = (window_size - 1) * stride + 1
    self._exhausted = False

  def _reset_buffer(self) -> None:
    """Drops all buffered elements after the parent has been repositioned."""
    self._buffer.clear()
    self._buffer_states.clear()
    self._exhausted = False

  def _fill_buffer(self) -> None:
    """Reads from the parent until one full window span is buffered."""
    while (
        not self._exhausted
        and len(self._buffer) < self._num_required_elements
    ):
      state_before = self._parent.get_state()
      try:
        element = next(self._parent)
      except StopIteration:
        self._exhausted = True
      else:
        self._buffer.append(element)
        self._buffer_states.append(state_before)

  def _advance(self) -> None:
    """Moves the window start forward by `shift` parent elements."""
    remaining = self._shift
    while remaining > 0 and self._buffer:
      self._buffer.popleft()
      self._buffer_states.popleft()
      remaining -= 1
    # `shift` exceeds the window span: the elements in the gap between two
    # windows belong to no window and are read and discarded.
    while remaining > 0 and not self._exhausted:
      try:
        next(self._parent)
      except StopIteration:
        self._exhausted = True
      else:
        remaining -= 1

  @stats.record_next_duration_if_output
  @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
  def __next__(self) -> T:
    """Returns the next window.

    Raises:
      StopIteration: If the parent is exhausted and either nothing is left to
        emit or the remaining elements do not fill a window while
        `drop_remainder` is set.
    """
    self._fill_buffer()
    if len(self._buffer) < self._num_required_elements:
      if self._drop_remainder or not self._buffer:
        raise StopIteration
    # The buffer never holds more than one window span, so taking every
    # `stride`-th element yields exactly `window_size` elements for a full
    # window and fewer for a trailing partial one.
    window = [
        self._buffer[i] for i in range(0, len(self._buffer), self._stride)
    ]
    self._advance()
    with self._stats.record_self_time():
      return self._stats.record_output_spec(self._window_fn(window))

  def get_state(self):
    """Returns the parent state from which the next window can be rebuilt.

    Elements that overlap with the next window have already been consumed from
    the parent, so the parent's current state would skip them on restore. The
    state captured before the first buffered element was read is returned
    instead; with an empty buffer the parent's current state is exact.
    """
    if self._buffer_states:
      return self._buffer_states[0]
    return self._parent.get_state()

  def set_state(self, state):
    """Restores a state produced by `get_state` and discards the buffer."""
    self._parent.set_state(state)
    self._reset_buffer()

  def _get_next_index(self) -> int:
    """Returns the index of the next window to be produced.

    NOTE(review): assumes `dataset.get_next_index` reports how many elements
    the parent has produced so far - confirm.
    """
    window_start = dataset.get_next_index(self._parent) - len(self._buffer)
    # Rounded up so that a start position past the end of the parent (reached
    # when the final shift could not be completed) still maps to the index
    # following the last emitted window.
    return (window_start + self._shift - 1) // self._shift

  def _set_next_index(self, index: int) -> None:
    """Positions the iterator so that window `index` is produced next."""
    # Window `index` starts at parent element `index * shift`.
    dataset.set_next_index(self._parent, index * self._shift)
    self._reset_buffer()

  def __str__(self) -> str:
    return (
        f"WindowDatasetIterator(window_size={self._window_size},"
        f" shift={self._shift},"
        f" stride={self._stride})"
    )
class WindowIterDataset(dataset.IterDataset[T]):
  """Sliding-window transformation, modelled on `grain.IterDataset.batch()`.

  Each output element is built from `window_size` parent elements that are
  `stride` apart; consecutive windows start `shift` parent elements apart.
  """

  def __init__(
      self,
      parent: dataset.IterDataset[S],
      window_size: int,
      shift: int = 1,
      stride: int = 1,
      window_fn: Callable[[Sequence[S]], T] | None = None,
      *,
      drop_remainder: bool = True,
  ):
    """Initializes the dataset.

    Args:
      parent: Dataset whose elements are windowed.
      window_size: Number of elements per window; must be greater than 1.
      shift: Number of parent elements between the starts of two consecutive
        windows; must be at least 1.
      stride: Distance in parent elements between neighbouring elements of a
        window; must be at least 1.
      window_fn: Combines the list of collected elements into one output
        element. Defaults to `make_batch` from the batch transformation.
      drop_remainder: Whether trailing windows with fewer than `window_size`
        elements are dropped. NOTE: `_element_spec` always reports
        `window_size` as the leading dimension, which only matches every
        element when this is true.

    Raises:
      ValueError: If `window_size`, `shift` or `stride` is out of range.
    """
    super().__init__(parent)
    if window_size <= 1:
      raise ValueError("window size must be positive and greater 1.")
    if shift < 1:
      # A non-positive shift would emit the same window forever.
      raise ValueError(f"shift must be at least 1, got {shift}.")
    if stride < 1:
      raise ValueError(f"stride must be at least 1, got {stride}.")
    self._window_size = window_size
    self._shift = shift
    self._stride = stride
    self._drop_remainder = drop_remainder
    self._window_fn = make_window if window_fn is None else window_fn

  def __iter__(self) -> _WindowDatasetIterator[T]:
    """Returns a fresh window iterator over a new parent iterator."""
    parent_iter = self._parent.__iter__()
    return _WindowDatasetIterator(
        parent_iter,
        self._window_size,
        shift=self._shift,
        stride=self._stride,
        drop_remainder=self._drop_remainder,
        window_fn=self._window_fn,
    )

  @property
  def _element_spec(self) -> Any:
    """Parent element spec with `window_size` prepended to every shape."""
    return _get_window_element_spec(
        dataset.get_element_spec(self._parent),
        self._window_size
    )

  def __str__(self) -> str:
    return (
        f"WindowIterDataset(window_size={self._window_size},"
        f" shift={self._shift},"
        f" stride={self._stride})"
    )