From 3d313d3f694621404325b476a1a07e90ff063421 Mon Sep 17 00:00:00 2001 From: Phani Aenugula Date: Tue, 3 Feb 2026 09:21:46 -0800 Subject: [PATCH] Add `ElasticIterDatasetIterator` to handle scaling up and down between checkpoints. * Allows users to keep their pipelines elastic and restore from a checkpoint with variable amount of hosts Limitations * Does not guarantee determinism between scaling * The limit of parallelism is the number of shards PiperOrigin-RevId: 864910611 --- CHANGELOG.md | 1 + grain/_src/python/dataset/elastic_iterator.py | 300 ++++++++++++++++-- .../python/dataset/elastic_iterator_test.py | 149 +++++++-- .../dataset/transformations/interleave.py | 35 +- .../transformations/interleave_test.py | 31 ++ grain/experimental.py | 5 +- 6 files changed, 458 insertions(+), 63 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e3b3fcccc..30f438e1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change * Exposes `SharedMemoryArrayMetadata` in a public API as a metadata descriptor for `SharedMemoryArray`. * `ParquetIterDataset` can read from multiple string paths interleaving reads. + * Add `ElasticIterDatasetIterator` for scaling up and down the number of shards between checkpoints. * Breaking changes: * Custom implementations of `RandomAccessDataSource` should accept `int` diff --git a/grain/_src/python/dataset/elastic_iterator.py b/grain/_src/python/dataset/elastic_iterator.py index a7140c172..92de70826 100644 --- a/grain/_src/python/dataset/elastic_iterator.py +++ b/grain/_src/python/dataset/elastic_iterator.py @@ -13,8 +13,9 @@ # limitations under the License. """Iterator supporting changes in the number of hosts (dataset shards).""" +import copy import functools -from typing import Any +from typing import Any, TypeVar, cast from grain._src.core import sharding from grain._src.python import options @@ -22,53 +23,203 @@ from grain._src.python.dataset.transformations import ( filter as filter_dataset, ) +from grain._src.python.dataset.transformations import interleave +from grain._src.python.dataset.transformations import prefetch + +T = TypeVar("T") _GLOBAL_NEXT_INDEX_STATE_KEY = "global_next_index" -class ElasticIterator(dataset.DatasetIterator): - """Iterator supporting recovery from a checkpoint after changes in sharding. +class ElasticIterDatasetIterator(dataset.DatasetIterator): + """Elastic iterator for InterleaveIterDatasets.""" - The input dataset is expected to be unbatched and unsharded. In order to - provide elasticity guarantee this iterator includes both, batching and - sharding. The iterator supports elastic re-configuration by having each - shard produce the same exact checkpoint (while producing different data) as - long as they are advanced the same number of steps. + def __init__( + self, + interleave_dataset: interleave.InterleaveIterDataset, + shard_options: sharding.ShardOptions, + global_batch_size: int, + drop_remainder: bool, + read_options: options.ReadOptions, + multiprocessing_options: options.MultiprocessingOptions | None = None, + ): + # We must set the slice on the original dataset so that the interleave + # iterator is created with the correct (sliced) datasets. + self._ds: interleave.InterleaveIterDataset = copy.deepcopy( + interleave_dataset + ) + self._num_dataset_shards = len(interleave_dataset._datasets) # pylint: disable=protected-access + self._num_host_shards = len(self._ds._datasets) # pylint: disable=protected-access + self._cycle_length = self._ds._cycle_length # pylint: disable=protected-access - State of any shard can be used to restore the state of all of the shards after - changes in sharding and global batch size. + self._global_batch_size = global_batch_size + self._drop_remainder = drop_remainder + self._shard_options = shard_options + self._read_options = read_options + self._multiprocessing_options = multiprocessing_options - This iterator explicitly disallows many-to-one transformations without - a fixed ratio, like `filter` and generic `IterDataset` transformations. - """ + # These will be initialized when the iterator is created. + self._iterator_started = False + self._is_batched = False + self._closed = False + + @property + def num_dataset_shards(self) -> int: + return self._num_dataset_shards + + @property + def num_host_shards(self) -> int: + return self._num_host_shards + + @property + def shard_options(self) -> sharding.ShardOptions: + return self._shard_options + + def close(self): + if self._closed: + return + self._closed = True + if "_iterator" in self.__dict__: + self._iterator.close() + + @functools.cached_property + def _iterator(self) -> dataset.DatasetIterator: + ds = self._ds + self._iterator_started = True + if self._global_batch_size > 0: + ds = ds.batch( + self._global_batch_size, drop_remainder=self._drop_remainder + ) + self._is_batched = True + if self._multiprocessing_options: + self._prefetch_wrapped = True + # ds = ds.mp_prefetch(self._multiprocessing_options) + return ds.__iter__() + + def __next__(self) -> Any: + return next(self._iterator) + + def get_state(self): + state = self._iterator.get_state() + ds_iterator_states = {} + + indices = state["iterators_in_use_indices"] + states = state["iterators_in_use_states"] + exhausted = state["exhausted"] + next_index_in_datasets = state["next_index_in_datasets"] + if self._is_batched: + interleave_iter = cast(interleave.InterleaveDatasetIterator, self._iterator._parent) # pylint: disable=protected-access + else: + interleave_iter = cast( + interleave.InterleaveDatasetIterator, self._iterator + ) + for i in range(self._num_host_shards): + shard_index = ( + i * self._shard_options.shard_count + self._shard_options.shard_index + ) + # If the current shard index is greater than or equal to the next + # index in datasets, it means the current shard has not yet started + # to be iterated on. + if i >= next_index_in_datasets: + ds_iterator_states[shard_index] = { + "exhausted": 0, + "state": interleave_iter._get_iterator_start_state(i), # pylint: disable=protected-access + } + elif i not in indices: + # These shards are exhausted but should still create a state to maintain + # static state spec shapes. + ds_iterator_states[shard_index] = { + "exhausted": 1, + "state": interleave_iter._get_iterator_start_state(i), # pylint: disable=protected-access + } + + for index, state, is_exhausted in zip(indices, states, exhausted): + # These shards are currently being iterated on. + shard_index = ( + index * self._shard_options.shard_count + + self._shard_options.shard_index + ) + ds_iterator_states[shard_index] = { + "exhausted": is_exhausted, + "state": state, + } + + return { + "ds_iterator_states": ds_iterator_states, + } + + def set_state(self, state): + """Sets state by reconstructing the state for the underlying interleave.""" + ds_iterator_states = state["ds_iterator_states"] + active_states = [] + + for shard_index, shard_state in sorted(ds_iterator_states.items()): + # Check if this state belongs to the current shard. + if ( + shard_index - self._shard_options.shard_index + ) % self._shard_options.shard_count == 0: + slice_index = shard_index // self._shard_options.shard_count + if not shard_state["exhausted"]: + active_states.append((slice_index, shard_state["state"])) + + iterators_in_use_indices = [] + iterators_in_use_states = [] + exhausted = [] + count = 0 + future_states = {} + for ind, state in active_states: + if count < self._cycle_length: + iterators_in_use_indices.append(ind) + iterators_in_use_states.append(state) + exhausted.append(0) + count += 1 + elif state: + # If a state exists for this iterator add it to future states + future_states[ind] = state + next_index_in_datasets = max(iterators_in_use_indices) + 1 + while count < self._cycle_length: + iterators_in_use_indices.append(next_index_in_datasets) + iterators_in_use_states.append(None) + exhausted.append(1) + count += 1 + + new_state = { + "next_index_in_cycle": 0, + "next_index_in_datasets": next_index_in_datasets, + "iterators_in_use_indices": iterators_in_use_indices, + "iterators_in_use_states": iterators_in_use_states, + "exhausted": exhausted, + "future_states": future_states, + } + if "_iterator" in self.__dict__: + self.__dict__["_iterator"].close() + self.__dict__.pop("_iterator", None) + self._iterator.set_state(new_state) + + +class _ElasticMapDatasetIterator(dataset.DatasetIterator): + """Iterator for MapDatasets in ElasticIterator.""" def __init__( self, ds: dataset.MapDataset, - global_batch_size: int, shard_options: sharding.ShardOptions, - *, + global_batch_size: int, + drop_remainder: bool, read_options: options.ReadOptions = options.ReadOptions(), multiprocessing_options: options.MultiprocessingOptions | None = None, ): - super().__init__() - to_check = [ds] - while to_check: - next_ds = to_check.pop() - if isinstance(next_ds, filter_dataset.FilterMapDataset): - raise ValueError( - "ElasticIterator does not support `filter` transformation." - ) - to_check.extend(next_ds.parents) self._ds = ds - self._global_batch_size = global_batch_size self._shard_options = shard_options - self._global_next_index = 0 + self._global_batch_size = global_batch_size + self._drop_remainder = drop_remainder self._read_options = read_options self._multiprocessing_options = multiprocessing_options + self._global_next_index = 0 + self._closed = False @functools.cached_property - def _iterator(self) -> dataset.DatasetIterator: + def _iterator(self): ds = self._ds[ self._global_next_index + self._shard_options.shard_index :: self._shard_options.shard_count @@ -83,13 +234,10 @@ def _iterator(self) -> dataset.DatasetIterator: ) ds = ds.batch(host_batch_size, drop_remainder=True) ds = ds.to_iter_dataset(read_options=self._read_options) - if self._multiprocessing_options is not None: + if self._multiprocessing_options: ds = ds.mp_prefetch(self._multiprocessing_options) return ds.__iter__() - def __iter__(self) -> dataset.DatasetIterator: - return self - def __next__(self) -> Any: result = next(self._iterator) self._global_next_index += self._global_batch_size @@ -100,7 +248,93 @@ def get_state(self) -> dict[str, Any]: _GLOBAL_NEXT_INDEX_STATE_KEY: self._global_next_index, } - def set_state(self, state: dict[str, Any]): + def close(self): + if self._closed: + return + self._closed = True + if "_iterator" in self.__dict__: + self._iterator.close() + + def set_state(self, state): self._global_next_index = state[_GLOBAL_NEXT_INDEX_STATE_KEY] - # Reset the iterator if it was already created. + if "_iterator" in self.__dict__: + self.__dict__["_iterator"].close() self.__dict__.pop("_iterator", None) + + +class ElasticIterDataset(dataset.IterDataset): + """Iterator supporting recovery from a checkpoint after changes in sharding. + + The input dataset is expected to be unbatched and unsharded. In order to + provide elasticity guarantee this iterator includes both, batching and + sharding. The iterator supports elastic re-configuration by having each + shard produce the same exact checkpoint (while producing different data) as + long as they are advanced the same number of steps. + + State of any shard can be used to restore the state of all of the shards after + changes in sharding and global batch size. + + This iterator explicitly disallows many-to-one transformations without + a fixed ratio, like `filter` and generic `IterDataset` transformations. + """ + + def __init__( + self, + parent: dataset.MapDataset | dataset.IterDataset, + global_batch_size: int, + shard_options: sharding.ShardOptions, + *, + read_options: options.ReadOptions = options.ReadOptions(), + multiprocessing_options: options.MultiprocessingOptions | None = None, + drop_remainder: bool = False, + ): + super().__init__() + self._ds = parent + if isinstance(parent, dataset.IterDataset): + prefetch._set_slice_iter_dataset( + self._ds, + slice(shard_options.shard_index, None, shard_options.shard_count), + ) + else: + to_check = [parent] + while to_check: + next_ds = to_check.pop() + if isinstance(next_ds, filter_dataset.FilterMapDataset): + raise ValueError( + "ElasticIterDataset does not support `filter` transformation." + ) + to_check.extend(next_ds.parents) + + self._shard_options = shard_options + self._global_batch_size = global_batch_size + self._drop_remainder = drop_remainder + self._read_options = read_options + self._multiprocessing_options = multiprocessing_options + + @property + def shard_options(self) -> sharding.ShardOptions: + return self._shard_options + + def __iter__(self) -> dataset.DatasetIterator: + if isinstance(self._ds, interleave.InterleaveIterDataset): + return ElasticIterDatasetIterator( + self._ds, + self._shard_options, + self._global_batch_size, + self._drop_remainder, + self._read_options, + self._multiprocessing_options, + ) + else: + return _ElasticMapDatasetIterator( + self._ds, + self._shard_options, + self._global_batch_size, + self._drop_remainder, + self._read_options, + self._multiprocessing_options, + ) + + +# Maintain compatibility with public code. +ElasticIterator = ElasticIterDataset diff --git a/grain/_src/python/dataset/elastic_iterator_test.py b/grain/_src/python/dataset/elastic_iterator_test.py index 1c4261f09..ffd1bb967 100644 --- a/grain/_src/python/dataset/elastic_iterator_test.py +++ b/grain/_src/python/dataset/elastic_iterator_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import platform from absl.testing import absltest from absl.testing import parameterized @@ -20,12 +19,13 @@ from grain._src.python import options from grain._src.python.dataset import dataset from grain._src.python.dataset import elastic_iterator +from grain._src.python.dataset.transformations import interleave import grain._src.python.testing.experimental as test_util import numpy as np @absltest.skipIf(platform.system() == "Windows", "Skipped under bazel.") -class ElasticIteratorTest(parameterized.TestCase): +class ElasticMapDataset(parameterized.TestCase): @parameterized.parameters( dict( @@ -58,12 +58,12 @@ def test_produces_correct_elements( ): ds = dataset.MapDataset.range(10).map(lambda x: x + 1) actual = list( - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, global_batch_size, shard_options, multiprocessing_options=multiprocessing_options, - ) + ).__iter__() ) np.testing.assert_equal( actual, expected, err_msg=f"actual: {actual}, expected: {expected}" @@ -71,17 +71,19 @@ def test_produces_correct_elements( def test_checkpointing(self): ds = dataset.MapDataset.range(100).map(lambda x: x * 2).shuffle(42) - it = elastic_iterator.ElasticIterator(ds, 5, sharding.NoSharding()) + it = elastic_iterator.ElasticIterDataset( + ds, 5, sharding.NoSharding() + ).__iter__() test_util.assert_equal_output_after_checkpoint(it) def test_checkpointing_with_multiprocessing(self): ds = dataset.MapDataset.range(5).map(lambda x: x * 2).shuffle(42) - it = elastic_iterator.ElasticIterator( + it = elastic_iterator.ElasticIterDataset( ds, 2, sharding.NoSharding(), multiprocessing_options=options.MultiprocessingOptions(2), - ) + ).__iter__() test_util.assert_equal_output_after_checkpoint(it) def _elastic_resize_test_base( @@ -116,22 +118,22 @@ def test_elastic_downsize(self): # Create iterators over 32 hosts with per-host batch size 2. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 64, sharding.ShardOptions(shard_index=i, shard_count=32), - ) + ).__iter__() for i in range(32) ] # Create new iterators over 16 hosts with per-host batch size 2. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 32, sharding.ShardOptions(shard_index=i, shard_count=16), - ) + ).__iter__() for i in range(16) ] @@ -147,28 +149,28 @@ def test_elastic_downsize_with_multiprocessing(self): # Create iterators over 8 hosts with per-host batch size 32. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 256, sharding.ShardOptions(shard_index=i, shard_count=8), multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(8) ] # Create new iterators over 4 hosts with per-host batch size 32. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 128, sharding.ShardOptions(shard_index=i, shard_count=4), multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(4) ] @@ -184,22 +186,22 @@ def test_elastic_upsize(self): # Create iterators over 8 hosts with per-host batch size 16. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 128, sharding.ShardOptions(shard_index=i, shard_count=8), - ) + ).__iter__() for i in range(8) ] # Create new iterators over 64 hosts with per-host batch size 2. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 128, sharding.ShardOptions(shard_index=i, shard_count=64), - ) + ).__iter__() for i in range(64) ] @@ -215,28 +217,28 @@ def test_elastic_upsize_with_multiprocessing(self): # Create iterators over 4 hosts with per-host batch size 16. def make_iterators_before(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 64, sharding.ShardOptions(shard_index=i, shard_count=4), multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(4) ] # Create new iterators over 6 hosts with per-host batch size 16. def make_iterators_after(): return [ - elastic_iterator.ElasticIterator( + elastic_iterator.ElasticIterDataset( ds, 96, sharding.ShardOptions(shard_index=i, shard_count=6), multiprocessing_options=options.MultiprocessingOptions( num_workers=2 ), - ) + ).__iter__() for i in range(6) ] @@ -249,9 +251,106 @@ def test_filter_raises_error(self): ds = ds.filter(lambda x: x % 2 == 0) with self.assertRaisesRegex( ValueError, - "ElasticIterator does not support `filter` transformation.", + "ElasticIterDataset does not support `filter` transformation.", ): - elastic_iterator.ElasticIterator(ds, 5, sharding.NoSharding()) + elastic_iterator.ElasticIterDataset( + ds, 5, sharding.NoSharding() + ).__iter__() + + +class ElasticIterDataset(parameterized.TestCase): + + @parameterized.parameters( + dict( + shard_options=sharding.NoSharding(), + global_batch_size=1, + expected=list(range(15)), + ), + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=1), + global_batch_size=1, + expected=list(range(15)), + ), + dict( + shard_options=sharding.NoSharding(), + global_batch_size=3, + # Data is interleaved with cycle length 3. + expected=[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8, 13], [4, 9, 14]], + ), + ) + def test_no_sharding_produces_correct_elements( + self, shard_options, global_batch_size, expected + ): + ds = [ + # 3 shards, each with 5 elements. + dataset.MapDataset.range(i * 5, (i + 1) * 5).to_iter_dataset() + for i in range(3) + ] + interleave_ds = interleave.InterleaveIterDataset( + ds, cycle_length=global_batch_size + ) + it = elastic_iterator.ElasticIterDataset( + interleave_ds, + shard_options=shard_options, + global_batch_size=global_batch_size, + ).__iter__() + actual = list(it) + self.assertLen(actual, len(expected)) + for actual_batch, expected_batch in zip(actual, expected): + np.testing.assert_equal(actual_batch, expected_batch) + + @parameterized.parameters( + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=2), + global_batch_size=1, + expected=[0, 2, 4, 6, 8], + ), + dict( + shard_options=sharding.ShardOptions(shard_index=1, shard_count=2), + global_batch_size=1, + expected=[1, 3, 5, 7, 9], + ), + dict( + shard_options=sharding.ShardOptions(shard_index=0, shard_count=2), + global_batch_size=2, + expected=[[0, 2], [4, 6], [8]], + ), + ) + def test_sharding_produces_correct_elements( + self, shard_options, global_batch_size, expected + ): + ds = [ + # 4 shards, 0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7] + dataset.MapDataset.range(i, 10, 4).to_iter_dataset() + for i in range(4) + ] + # Use cycle_length=2 as in the original test. + interleave_ds = interleave.InterleaveIterDataset(ds, cycle_length=2) + it = elastic_iterator.ElasticIterDataset( + interleave_ds, + shard_options=shard_options, + global_batch_size=global_batch_size, + ).__iter__() + actual = list(it) + self.assertLen(actual, len(expected)) + for actual_batch, expected_batch in zip(actual, expected): + np.testing.assert_equal(actual_batch, expected_batch) + + def test_checkpointing_no_change(self): + ds = [ + dataset.MapDataset.range(i, 100, 25).to_iter_dataset() + for i in range(25) + ] + global_batch_size = 2 + interleave_ds = interleave.InterleaveIterDataset( + ds, cycle_length=global_batch_size + ) + it = elastic_iterator.ElasticIterDataset( + interleave_ds, + shard_options=sharding.ShardOptions(shard_index=2, shard_count=4), + global_batch_size=global_batch_size, + ).__iter__() + test_util.assert_equal_output_after_checkpoint(it) if __name__ == "__main__": diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index 9db05ba90..a3a299d12 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -52,11 +52,11 @@ def __init__( functools.partial( _add_prefetch_and_make_iterator, # We use weakref to avoid a circular reference. The - # _InterleaveDatasetIterator holds a reference to the + # InterleaveDatasetIterator holds a reference to the # prefetch iterator in `self._prefetch_ds_iter`. # The call to `_add_prefetch_and_make_iterator` (and the # partial object) would hold a reference to the - # _InterleaveDatasetIterator. This would prolong its lifetime + # InterleaveDatasetIterator. This would prolong its lifetime # leading to increased resource usage. interleave_iterator=weakref.ref(self), start_prefetch=True, @@ -86,6 +86,8 @@ def __init__( self._exhausted_iterators: list[ tuple[int, dataset.DatasetIterator[T]] | None ] = [None] * self._cycle_length + # Future states used for elastic iterators + self._future_states: dict[int, Any] = {} @stats.record_next_duration_if_output @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING) @@ -142,6 +144,16 @@ def __next__(self) -> T: self._iterators_in_use_indices[self._next_index_in_cycle] = ( self._next_index_in_datasets ) + # For elastic iterators, we might have a future state saved for this + # dataset iterator from which to resume from. + if ( + self._next_index_in_datasets in self._future_states + and self._iterators_in_use[self._next_index_in_cycle] + ): + future_state = self._future_states.pop(self._next_index_in_datasets) + self._iterators_in_use[self._next_index_in_cycle].set_state( + future_state + ) self._next_index_in_datasets += 1 elif not any(self._iterators_in_use): raise StopIteration @@ -182,13 +194,15 @@ def get_state(self): int(self._exhausted_iterator_state[i] is not None) for i in range(self._cycle_length) ] - return { + state = { "next_index_in_cycle": self._next_index_in_cycle, "next_index_in_datasets": self._next_index_in_datasets, "iterators_in_use_indices": self._iterators_in_use_indices.copy(), "iterators_in_use_states": iterators_in_use_states, "exhausted": exhausted, + "future_states": self._future_states, } + return state def set_state(self, state): exhausted = state["exhausted"] @@ -220,7 +234,9 @@ def set_state(self, state): interleave_iterator=weakref.ref(self), start_prefetch=False, ) - iterator.set_state(it_state) + # Only update the iterator state if it is given + if it_state: + iterator.set_state(it_state) self._iterators_in_use[index_in_cycle] = iterator else: self._exhausted_iterator_state[index_in_cycle] = it_state @@ -232,6 +248,7 @@ def set_state(self, state): self._next_index_in_cycle = state["next_index_in_cycle"] self._next_index_in_datasets = state["next_index_in_datasets"] self._iterators_in_use_indices = state["iterators_in_use_indices"] + self._future_states = state.get("future_states", {}) def _get_next_index(self) -> int: if len(self._datasets) == 1: @@ -307,6 +324,16 @@ def __str__(self) -> str: f" cycle_length={self._cycle_length})" ) + def _get_iterator_start_state(self, index: int) -> dict[str, Any]: + it = _add_prefetch_and_make_iterator( + self._datasets[index], + weakref.ref(self), + start_prefetch=False, + ) + state = it.get_state() + del it + return state + def _add_prefetch_and_make_iterator( ds: dataset.IterDataset[T] | dataset.MapDataset[T], diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index c5d647307..a5ed001a8 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -291,6 +291,37 @@ def test_set_next_index_with_multiple_datasets(self): ): dataset.set_next_index(ds_iter, 0) + def test_future_states(self): + datasets = [ + dataset.MapDataset.source([1, 2]).to_iter_dataset(), + dataset.MapDataset.source([3, 4]).to_iter_dataset(), + ] + ds = interleave.InterleaveIterDataset(datasets, cycle_length=1) + ds_iter = ds.__iter__() + + # Initialize the first iterator and get state. + state = ds_iter.get_state() + + # Get state for the second dataset iterator after advancing it. + ds1_iter = datasets[1].__iter__() + next(ds1_iter) # Consumes 3 + ds1_state = ds1_iter.get_state() + + # Inject future state for the second dataset (index 1). + state["future_states"] = {1: ds1_state} + + ds_iter.set_state(state) + + # Consume elements. + # It should yield elements from the first dataset (1, 2) and then + # yield elements from the second dataset starting from the future state (4). + self.assertEqual(next(ds_iter), 1) + self.assertEqual(next(ds_iter), 2) + self.assertEqual(next(ds_iter), 4) + + with self.assertRaises(StopIteration): + next(ds_iter) + if __name__ == "__main__": absltest.main() diff --git a/grain/experimental.py b/grain/experimental.py index 297c40b56..7537b4aa8 100644 --- a/grain/experimental.py +++ b/grain/experimental.py @@ -32,7 +32,10 @@ apply_transformations, WithOptionsIterDataset, ) -from grain._src.python.dataset.elastic_iterator import ElasticIterator +from grain._src.python.dataset.elastic_iterator import ( + ElasticIterDatasetIterator, + ElasticIterDataset, +) from grain._src.python.dataset.sources.parquet_dataset import ParquetIterDataset from grain._src.python.dataset.sources.tfrecord_dataset import TFRecordIterDataset