Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 138 additions & 1 deletion grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@

from collections.abc import Sequence
import functools
import queue
import threading
from typing import Any, TypeVar
import weakref

from concurrent import futures
from grain._src.python import options as grain_options
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
Expand Down Expand Up @@ -308,9 +311,132 @@ def __str__(self) -> str:
)


class NonDeterministicInterleaveDatasetIterator(dataset.DatasetIterator[T]):
  """Iterates over datasets non-deterministically using a thread pool.

  Up to `cycle_length` datasets are consumed concurrently, each by a worker
  task that drains it into a bounded per-slot queue. `__next__` returns the
  first buffered element it finds, so the output order depends on thread
  scheduling and checkpoint restore (`set_state`) is not supported.
  """

  def __init__(
      self,
      datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
      cycle_length: int,
      num_make_iter_threads: int = 1,
      iter_buffer_size: int = 1,
  ):
    """Initializes the iterator.

    Args:
      datasets: The datasets to interleave.
      cycle_length: Number of datasets consumed concurrently; capped at
        `len(datasets)`.
      num_make_iter_threads: Unused by this iterator; accepted for signature
        compatibility with the deterministic interleave iterator.
      iter_buffer_size: Number of elements buffered per concurrent dataset;
        values below 1 are clamped to 1.
    """
    super().__init__()
    self._datasets = datasets
    self._cycle_length = min(cycle_length, len(datasets))
    self._iter_buffer_size = max(1, iter_buffer_size)

    # Dataset `i` feeds queue `i % cycle_length`: a dataset submitted after an
    # earlier one finishes reuses that slot's queue.
    self._queues = [
        queue.Queue(maxsize=self._iter_buffer_size)
        for _ in range(self._cycle_length)
    ]
    self._next_queue_to_check = 0
    self._should_stop = threading.Event()
    self._executor = futures.ThreadPoolExecutor(
        max_workers=self._cycle_length,
        thread_name_prefix="grain-non-det-interleave-pool",
    )
    # Number of input datasets whose end-of-data sentinel has been consumed by
    # `__next__`. Writes are guarded by `_lock`.
    self._finished_datasets_count = 0
    self._lock = threading.Lock()
    self._started = False
    # Index of the next dataset to hand to the executor once a running one
    # finishes. Overwritten in `_start_workers` to account for the eagerly
    # submitted datasets.
    self._next_dataset_to_submit = self._cycle_length

  def _start_workers(self):
    """Eagerly submits the initial batch of datasets to the executor."""
    with self._lock:
      if self._started:
        return
      self._started = True
      # We seed the executor with twice the amount of work so that the executor
      # is never waiting for work as it depletes its task queue.
      initial_count = min(self._cycle_length * 2, len(self._datasets))
      # Continue submissions after the eagerly seeded datasets. Without this,
      # datasets with index in [cycle_length, initial_count) would be submitted
      # twice — once here and once from `_submit_next_dataset` — and their
      # elements would be produced twice.
      self._next_dataset_to_submit = initial_count
      for i in range(initial_count):
        self._executor.submit(self._worker_fn, i)

  def _submit_next_dataset(self):
    """Submits the next unprocessed dataset, if any, to the executor."""
    with self._lock:
      if self._should_stop.is_set():
        return
      if self._next_dataset_to_submit < len(self._datasets):
        index_to_submit = self._next_dataset_to_submit
        self._next_dataset_to_submit += 1
        self._executor.submit(self._worker_fn, index_to_submit)

  def _worker_fn(self, index: int):
    """Drains dataset `index` into its slot queue until exhausted or stopped."""
    buffer = self._queues[index % self._cycle_length]
    iterator = _add_prefetch_and_make_iterator(
        self._datasets[index],
        interleave_iterator=weakref.ref(self),
        start_prefetch=True,
    )
    try:
      while not self._should_stop.is_set():
        element = iterator.__next__()
        buffer.put((element, None))
    except StopIteration:
      # End-of-data sentinel; consumed and counted by `__next__`.
      buffer.put((None, StopIteration))
      self._submit_next_dataset()
    except Exception as e:  # pylint: disable=broad-except
      # Propagate the failure to the consumer thread.
      buffer.put((None, e))

  @stats.record_next_duration_if_output
  @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
  def __next__(self) -> T:
    self._assert_not_closed()
    if not self._started:
      self._start_workers()
    _ = self._stats  # eagerly initialize stats

    while True:
      # All end-of-data sentinels consumed and no buffered elements remain.
      if self._finished_datasets_count == len(self._datasets) and all(
          q.empty() for q in self._queues
      ):
        raise StopIteration

      got_item = False
      for idx in range(self._cycle_length):
        q = self._queues[idx]
        try:
          item, err = q.get_nowait()
          got_item = True
          if err is StopIteration:
            with self._lock:
              self._finished_datasets_count += 1
            # Rescan from slot 0; the finished slot may be refilled by a
            # successor dataset.
            break
          elif err is not None:
            raise err
          item = self._stats.record_bytes_produced(item)
          return self._stats.record_output_spec(item)
        except queue.Empty:
          continue
      if not got_item:
        # All queues were momentarily empty: back off briefly instead of
        # busy-spinning at 100% CPU while producers catch up.
        self._should_stop.wait(0.001)

  def get_state(self) -> dict[str, Any]:
    """Returns a minimal, informational (non-restorable) state summary."""
    return {
        "finished_datasets_count": self._finished_datasets_count,
        "num_datasets": len(self._datasets),
    }

  def set_state(self, state: dict[str, Any]) -> None:
    """Unsupported: output order depends on thread scheduling."""
    raise NotImplementedError(
        "set_state is not supported for non-deterministic interleaving."
    )

  def close(self) -> None:
    """Stops workers and drains queues so blocked producers can exit."""
    if self._closed:
      return
    self._closed = True
    self._should_stop.set()
    self._executor.shutdown(wait=False)
    # Unblock workers stuck on a full queue; they observe `_should_stop` and
    # exit after at most one more `put`.
    for q in self._queues:
      while not q.empty():
        try:
          q.get_nowait()
        except queue.Empty:
          break


def _add_prefetch_and_make_iterator(
ds: dataset.IterDataset[T] | dataset.MapDataset[T],
interleave_iterator: weakref.ref[InterleaveDatasetIterator[T]],
interleave_iterator: weakref.ref[
InterleaveDatasetIterator | NonDeterministicInterleaveDatasetIterator
],
start_prefetch: bool,
) -> dataset.DatasetIterator[T]:
"""Adds prefetching to an IterDataset and returns an iterator.
Expand Down Expand Up @@ -384,6 +510,7 @@ def __init__(
num_make_iter_threads: int = 1,
make_iter_buffer_size: int = 1,
iter_buffer_size: int = 1,
deterministic: bool = True,
):
"""Initializes the InterleaveIterDataset.

Expand All @@ -405,15 +532,25 @@ def __init__(
is 1, with this we'll always keep the next iterator ready in advance.
iter_buffer_size: Optional. The number of elements to prefetch from each
iterator. Default value is 1.
deterministic: Optional. If True, the iterators will be cycled through in
a round-robin fashion. If False, the next element will be taken from the
first iterator that has an element ready. Default value is True.
"""
super().__init__()
self._datasets = datasets
self._cycle_length = cycle_length
self._num_make_iter_threads = num_make_iter_threads
self._make_iter_buffer_size = make_iter_buffer_size
self._iter_buffer_size = iter_buffer_size
self._deterministic = deterministic

def __iter__(self) -> dataset.DatasetIterator[T]:
if not self._deterministic:
return NonDeterministicInterleaveDatasetIterator(
self._datasets,
cycle_length=self._cycle_length,
iter_buffer_size=self._iter_buffer_size,
)
return InterleaveDatasetIterator(
self._datasets,
cycle_length=self._cycle_length,
Expand Down
39 changes: 37 additions & 2 deletions grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave
from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
from grain._src.python.testing import experimental
import numpy as np


Expand Down Expand Up @@ -187,7 +187,7 @@ def test_checkpointing_comprehensive(self):
for i in range(1, 6)
]
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
assert_equal_output_after_checkpoint(ds)
experimental.assert_equal_output_after_checkpoint(ds)

def test_set_state_does_not_recreate_iterators_if_not_needed(self):
cycle_length = 5
Expand Down Expand Up @@ -291,6 +291,41 @@ def test_set_next_index_with_multiple_datasets(self):
):
dataset.set_next_index(ds_iter, 0)

def test_non_deterministic_interleave(self):
ds1 = dataset.MapDataset.range(10).to_iter_dataset()
ds2 = dataset.MapDataset.range(10, 20).to_iter_dataset()
ds = interleave.InterleaveIterDataset(
[ds1, ds2],
cycle_length=2,
iter_buffer_size=2,
deterministic=False,
)
it = ds.__iter__()
first_element = next(it)
self.assertIn(first_element, [0, 10])

# We should produce all 20 elements.
elements = list(it)
elements += [first_element]
self.assertLen(elements, 20)
self.assertEqual(sorted(elements), list(range(0, 20)))

def test_non_deterministic_interleave_unsupported_checkpointing(self):
ds1 = dataset.MapDataset.range(10).to_iter_dataset()
ds2 = dataset.MapDataset.range(10, 20).to_iter_dataset()
ds = interleave.InterleaveIterDataset(
[ds1, ds2],
cycle_length=2,
deterministic=False,
)
it = ds.__iter__()

with self.assertRaisesRegex(
NotImplementedError,
"set_state is not supported for non-deterministic interleaving.",
):
it.set_state({})


if __name__ == "__main__":
absltest.main()
Loading