From 9522a48c7a56e3ffe20a9f6d4f728b9583d62ecf Mon Sep 17 00:00:00 2001
From: Sagun Bajra
Date: Wed, 25 Mar 2026 12:55:56 -0700
Subject: [PATCH] Add an option to prefetch from the first available iterator
 in InterleaveDataset guarded by the `allow_reordering` option.

PiperOrigin-RevId: 889390420
---
 .../dataset/transformations/interleave.py     | 193 +++++++++++++++++-
 .../transformations/interleave_test.py        |  55 ++++-
 2 files changed, 245 insertions(+), 3 deletions(-)

diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py
index 9db05ba90..dafd46f88 100644
--- a/grain/_src/python/dataset/transformations/interleave.py
+++ b/grain/_src/python/dataset/transformations/interleave.py
@@ -15,9 +15,12 @@
 
 from collections.abc import Sequence
 import functools
+import queue
+import threading
 from typing import Any, TypeVar
 import weakref
 
+from concurrent import futures
 from grain._src.python import options as grain_options
 from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
@@ -308,9 +311,185 @@ def __str__(self) -> str:
     )
 
 
+class NonDeterministicInterleaveDatasetIterator(dataset.DatasetIterator[T]):
+  """Iterates over datasets non-deterministically using a thread pool.
+
+  Up to `cycle_length` worker threads each drain one input dataset into a
+  bounded queue. `__next__` returns whichever buffered element it finds first,
+  so the output order depends on thread scheduling. Checkpoint restoration via
+  `set_state` is not supported.
+  """
+
+  def __init__(
+      self,
+      datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
+      cycle_length: int,
+      num_make_iter_threads: int = 1,
+      iter_buffer_size: int = 1,
+  ):
+    """Initializes the iterator; worker threads start on the first `__next__`.
+
+    Args:
+      datasets: Datasets to interleave.
+      cycle_length: Maximum number of datasets read concurrently. Clamped to
+        `len(datasets)`.
+      num_make_iter_threads: Currently unused by this iterator.
+      iter_buffer_size: Capacity of each per-slot element queue. Values below 1
+        are treated as 1.
+    """
+    super().__init__()
+    self._datasets = datasets
+    self._cycle_length = min(cycle_length, len(datasets))
+    self._iter_buffer_size = max(1, iter_buffer_size)
+
+    self._queues = [
+        queue.Queue(maxsize=self._iter_buffer_size)
+        for _ in range(self._cycle_length)
+    ]
+    self._next_queue_to_check = 0
+    self._should_stop = threading.Event()
+    self._executor = futures.ThreadPoolExecutor(
+        # `ThreadPoolExecutor` rejects `max_workers=0` (empty `datasets`).
+        max_workers=max(1, self._cycle_length),
+        thread_name_prefix="grain-non-det-interleave-pool",
+    )
+    self._finished_datasets_count = 0
+    self._lock = threading.Lock()
+    self._started = False
+    # Counts entries that workers have put on `self._queues` and `__next__` has
+    # not yet claimed, so the consumer can block instead of polling.
+    self._entries_ready = threading.Semaphore(0)
+    # We seed the executor with twice the amount of work so that the executor
+    # is never waiting for work as it depletes its task queue.
+    self._initial_submit_count = min(self._cycle_length * 2, len(datasets))
+    # Datasets below this index are submitted by `_start_workers`, so on-demand
+    # submission must resume after them or they would be read twice.
+    self._next_dataset_to_submit = self._initial_submit_count
+
+  def _start_workers(self):
+    """Submits the initial batch of datasets to the executor exactly once."""
+    with self._lock:
+      if self._started:
+        return
+      self._started = True
+    for i in range(self._initial_submit_count):
+      self._executor.submit(self._worker_fn, i)
+
+  def _submit_next_dataset(self):
+    """Submits the next not-yet-submitted dataset, unless stopping."""
+    with self._lock:
+      if self._should_stop.is_set():
+        return
+      if self._next_dataset_to_submit < len(self._datasets):
+        index_to_submit = self._next_dataset_to_submit
+        self._next_dataset_to_submit += 1
+        self._executor.submit(self._worker_fn, index_to_submit)
+
+  def _worker_fn(self, index: int):
+    """Drains `self._datasets[index]` into its slot queue.
+
+    Each queue entry is an `(element, error)` pair: `error` is None for a data
+    element, the `StopIteration` class once the dataset is exhausted, or the
+    exception instance raised while creating or reading the iterator.
+
+    Args:
+      index: Index into `self._datasets` of the dataset to read.
+    """
+    buffer = self._queues[index % self._cycle_length]
+    try:
+      # Created inside the `try` so that a failure here reaches the consumer
+      # instead of being swallowed by the executor's future.
+      iterator = _add_prefetch_and_make_iterator(
+          self._datasets[index],
+          interleave_iterator=weakref.ref(self),
+          start_prefetch=True,
+      )
+      while not self._should_stop.is_set():
+        element = iterator.__next__()
+        buffer.put((element, None))
+        self._entries_ready.release()
+    except StopIteration:
+      buffer.put((None, StopIteration))
+      self._entries_ready.release()
+      self._submit_next_dataset()
+    except Exception as e:  # pylint: disable=broad-except
+      buffer.put((None, e))
+      self._entries_ready.release()
+
+  @stats.record_next_duration_if_output
+  @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
+  def __next__(self) -> T:
+    """Returns the next buffered element from any of the datasets.
+
+    Raises:
+      StopIteration: Once every dataset is exhausted and all queues are empty.
+      Exception: Any exception a worker hit while reading its dataset.
+    """
+    self._assert_not_closed()
+    if not self._started:
+      self._start_workers()
+    _ = self._stats  # eagerly initialize stats
+
+    while True:
+      if self._finished_datasets_count == len(self._datasets) and all(
+          q.empty() for q in self._queues
+      ):
+        raise StopIteration
+
+      # Block until a worker publishes an entry instead of spinning over empty
+      # queues. Each `release` follows exactly one `put`, so with a single
+      # consumer thread the scan below is guaranteed to find an entry.
+      self._entries_ready.acquire()
+      for idx in range(self._cycle_length):
+        q = self._queues[idx]
+        try:
+          item, err = q.get_nowait()
+        except queue.Empty:
+          continue
+        if err is StopIteration:
+          with self._lock:
+            self._finished_datasets_count += 1
+          break
+        if err is not None:
+          raise err
+        item = self._stats.record_bytes_produced(item)
+        return self._stats.record_output_spec(item)
+
+  def get_state(self) -> dict[str, Any]:
+    """Returns progress counters; they cannot be restored via `set_state`."""
+    return {
+        "finished_datasets_count": self._finished_datasets_count,
+        "num_datasets": len(self._datasets),
+    }
+
+  def set_state(self, state: dict[str, Any]) -> None:
+    """Always raises: the element order is not reproducible."""
+    raise NotImplementedError(
+        "set_state is not supported for non-deterministic interleaving."
+    )
+
+  def close(self) -> None:
+    """Signals workers to stop, shuts down the executor and drops buffers."""
+    if self._closed:
+      return
+    self._closed = True
+    self._should_stop.set()
+    self._executor.shutdown(wait=False)
+    # Draining makes room for workers blocked in `put`, letting them re-check
+    # `_should_stop` and exit.
+    for q in self._queues:
+      while not q.empty():
+        try:
+          q.get_nowait()
+        except queue.Empty:
+          break
+
+
 def _add_prefetch_and_make_iterator(
     ds: dataset.IterDataset[T] | dataset.MapDataset[T],
-    interleave_iterator: weakref.ref[InterleaveDatasetIterator[T]],
+    interleave_iterator: weakref.ref[
+        InterleaveDatasetIterator | NonDeterministicInterleaveDatasetIterator
+    ],
     start_prefetch: bool,
 ) -> dataset.DatasetIterator[T]:
   """Adds prefetching to an IterDataset and returns an iterator.
@@ -384,6 +563,7 @@ def __init__(
       num_make_iter_threads: int = 1,
       make_iter_buffer_size: int = 1,
       iter_buffer_size: int = 1,
+      deterministic: bool = True,
   ):
     """Initializes the InterleaveIterDataset.
 
@@ -405,6 +585,10 @@ def __init__(
         is 1, with this we'll always keep the next iterator ready in advance.
       iter_buffer_size: Optional. The number of elements to prefetch from each
         iterator. Default value is 1.
+      deterministic: Optional. If True, the iterators will be cycled through in
+        a round-robin fashion. If False, the next element will be taken from the
+        first iterator that has an element ready, and restoring from a
+        checkpoint (`set_state`) is not supported. Default value is True.
     """
     super().__init__()
     self._datasets = datasets
@@ -412,8 +596,15 @@ def __init__(
     self._num_make_iter_threads = num_make_iter_threads
     self._make_iter_buffer_size = make_iter_buffer_size
     self._iter_buffer_size = iter_buffer_size
+    self._deterministic = deterministic
 
   def __iter__(self) -> dataset.DatasetIterator[T]:
+    if not self._deterministic:
+      return NonDeterministicInterleaveDatasetIterator(
+          self._datasets,
+          cycle_length=self._cycle_length,
+          iter_buffer_size=self._iter_buffer_size,
+      )
     return InterleaveDatasetIterator(
         self._datasets,
         cycle_length=self._cycle_length,
diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py
index c5d647307..786411310 100644
--- a/grain/_src/python/dataset/transformations/interleave_test.py
+++ b/grain/_src/python/dataset/transformations/interleave_test.py
@@ -20,7 +20,7 @@
 from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset.transformations import interleave
-from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
+from grain._src.python.testing import experimental
 import numpy as np
 
 
@@ -187,7 +187,7 @@ def test_checkpointing_comprehensive(self):
         for i in range(1, 6)
     ]
     ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
-    assert_equal_output_after_checkpoint(ds)
+    experimental.assert_equal_output_after_checkpoint(ds)
 
   def test_set_state_does_not_recreate_iterators_if_not_needed(self):
     cycle_length = 5
@@ -291,6 +291,57 @@ def test_set_next_index_with_multiple_datasets(self):
     ):
       dataset.set_next_index(ds_iter, 0)
 
+  def test_non_deterministic_interleave(self):
+    ds1 = dataset.MapDataset.range(10).to_iter_dataset()
+    ds2 = dataset.MapDataset.range(10, 20).to_iter_dataset()
+    ds = interleave.InterleaveIterDataset(
+        [ds1, ds2],
+        cycle_length=2,
+        iter_buffer_size=2,
+        deterministic=False,
+    )
+    it = ds.__iter__()
+    first_element = next(it)
+    self.assertIn(first_element, [0, 10])
+
+    # We should produce all 20 elements.
+    elements = list(it)
+    elements += [first_element]
+    self.assertLen(elements, 20)
+    self.assertEqual(sorted(elements), list(range(0, 20)))
+
+  def test_non_deterministic_interleave_more_datasets_than_cycle_length(self):
+    # Every dataset must be read exactly once, including the ones that are
+    # submitted only after an earlier dataset is exhausted.
+    datasets = [
+        dataset.MapDataset.range(i * 10, (i + 1) * 10).to_iter_dataset()
+        for i in range(5)
+    ]
+    ds = interleave.InterleaveIterDataset(
+        datasets,
+        cycle_length=2,
+        deterministic=False,
+    )
+    elements = list(ds)
+    self.assertLen(elements, 50)
+    self.assertEqual(sorted(elements), list(range(50)))
+
+  def test_non_deterministic_interleave_unsupported_checkpointing(self):
+    ds1 = dataset.MapDataset.range(10).to_iter_dataset()
+    ds2 = dataset.MapDataset.range(10, 20).to_iter_dataset()
+    ds = interleave.InterleaveIterDataset(
+        [ds1, ds2],
+        cycle_length=2,
+        deterministic=False,
+    )
+    it = ds.__iter__()
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "set_state is not supported for non-deterministic interleaving.",
+    ):
+      it.set_state({})
+
 
 if __name__ == "__main__":
   absltest.main()