class TunableInterleaveDatasetIterator(dataset.DatasetIterator[T]):
  """Iterator for InterleaveIterDataset with runtime-tunable parameters.

  Interleaves elements from `datasets` by cycling round-robin through at most
  `cycle_length` live child iterators. When a child iterator is exhausted, its
  slot is refilled with an iterator over the next unused dataset. Child
  iterator creation can be offloaded to a thread pool and buffered ahead of
  time. The tuning knobs (`cycle_length`, `num_make_iter_threads`,
  `make_iter_buffer_size`, `iter_buffer_size`) can be changed after
  construction via the `_set_*` methods, e.g. when restoring a checkpoint
  recorded with different settings.
  """

  def __init__(
      self,
      datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
      cycle_length: int,
      num_make_iter_threads: int = 1,
      make_iter_buffer_size: int = 1,
      iter_buffer_size: int = 1,
  ):
    """Initializes the interleave iterator.

    Args:
      datasets: Datasets to interleave; consumed in order.
      cycle_length: Maximum number of child iterators read from concurrently
        in round-robin order.
      num_make_iter_threads: Number of threads creating child iterators in
        the background; 0 disables asynchronous iterator creation.
      make_iter_buffer_size: Number of child iterators to create ahead of
        time; 0 disables asynchronous iterator creation.
      iter_buffer_size: Element prefetch buffer size applied to each child
        iterator (consumed by `_add_prefetch_and_make_iterator`).
    """
    super().__init__()
    self._datasets = datasets
    self._cycle_length = cycle_length
    self._num_make_iter_threads = num_make_iter_threads
    self._make_iter_buffer_size = make_iter_buffer_size
    self._iter_buffer_size = iter_buffer_size
    self._started = False

    # cycle_length iterators that are in use.
    self._iterators_in_use: list[dataset.DatasetIterator[T] | None] = [
        None
    ] * self._cycle_length
    # Indices of datasets being iterated over in _iterators_in_use; -1 marks
    # a slot that has not been filled yet.
    self._iterators_in_use_indices: list[int] = [
        -1 for _ in range(self._cycle_length)
    ]
    # Up to make_iter_buffer_size iterators that are prefetched in the
    # background as futures.
    self._queued_iterators: collections.deque[
        futures.Future[dataset.DatasetIterator[T]]
    ] = collections.deque()
    # Executor for creating iterators and starting prefetch asynchronously.
    self._executor: futures.ThreadPoolExecutor | None = None
    # Index of the next iterator to be used in _iterators_in_use.
    self._next_index_in_cycle: int = 0
    # Index in _datasets of the next dataset to be put into
    # _iterators_in_use.
    self._next_index_in_datasets: int = 0
    # Index in _datasets of the next dataset with no queued or in-use
    # iterator yet.
    self._next_index_in_unbuffered_datasets: int = 0
    # Future states of the iterators that were being used before cycle_length
    # was reduced.
    self._future_states: dict[str, dict[str, Any]] = {}
    # Placeholder state reported for empty slots once datasets are exhausted.
    # This is used for Pathways Remote Python, which requires the state spec
    # to remain consistent. This works when the input datasets have the same
    # state spec.
    self._placeholder_state: dict[str, Any] | None = None

  def _increment_next_index_in_cycle(self):
    """Advances the round-robin cursor, wrapping at cycle_length."""
    self._next_index_in_cycle = (
        self._next_index_in_cycle + 1
    ) % self._cycle_length

  def _fill_queued_iterators(self):
    """Tops up the background queue to make_iter_buffer_size futures."""
    if self._make_iter_buffer_size == 0 or self._num_make_iter_threads == 0:
      # Iterator prefetching is not possible when threads or buffer size
      # are 0.
      return
    while len(self._queued_iterators) < self._make_iter_buffer_size:
      if self._next_index_in_unbuffered_datasets >= len(self._datasets):
        # No datasets left to buffer.
        break
      future = self._create_iterator_asynchronously(
          self._next_index_in_unbuffered_datasets
      )
      self._queued_iterators.append(future)
      self._next_index_in_unbuffered_datasets += 1

  def start_prefetch(self):
    """Starts the executor and begins creating iterators in the background.

    No-op if already started. When asynchronous creation is disabled (zero
    threads or zero buffer size), no executor is created and iterators will
    be built synchronously on demand in `__next__`.
    """
    if self._started:
      return
    if self._num_make_iter_threads > 0 and self._make_iter_buffer_size > 0:
      assert self._executor is None
      self._executor = futures.ThreadPoolExecutor(
          self._num_make_iter_threads, thread_name_prefix="grain-interleave"
      )
    self._fill_queued_iterators()
    self._started = True

  def _stop_prefetch(self):
    """Cancels queued iterator futures and shuts down the executor."""
    if not self._started:
      return
    self._started = False
    while self._queued_iterators:
      future = self._queued_iterators.popleft()
      future.cancel()
    # Return the queued datasets to the unbuffered pool so they can be
    # re-queued if prefetching is restarted.
    self._next_index_in_unbuffered_datasets = self._next_index_in_datasets
    if self._executor is not None:
      self._executor.shutdown(wait=True)
      self._executor = None

  def close(self):
    """Closes all in-use and queued child iterators and stops prefetching."""
    for iterator in self._iterators_in_use:
      if iterator is not None:
        iterator.close()
    for future in self._queued_iterators:
      # NOTE(review): result() blocks until creation completes and re-raises
      # creation errors; a previously cancelled future would raise
      # CancelledError here — confirm queued futures are never cancelled
      # before close().
      iterator = future.result()
      iterator.close()
    self._stop_prefetch()

  def _set_cycle_length(self, new_cycle_length: int):
    """Changes the number of concurrently interleaved child iterators.

    Args:
      new_cycle_length: New cycle length; values larger than len(datasets)
        are clamped.

    Raises:
      ValueError: If `new_cycle_length` is negative.
    """
    if new_cycle_length < 0:
      raise ValueError("cycle_length must be non-negative.")
    if new_cycle_length > len(self._datasets):
      new_cycle_length = len(self._datasets)
    if new_cycle_length == self._cycle_length:
      return

    if new_cycle_length < self._cycle_length:
      # Move the iterators that are no longer in use to the queued iterators.
      # Reversed order + appendleft keeps them in their original dataset
      # order at the front of the queue.
      for i in reversed(range(new_cycle_length, self._cycle_length)):
        it = self._iterators_in_use[i]
        if it is not None:
          future = futures.Future()
          future.set_result(it)
          self._queued_iterators.appendleft(future)
      # Ensure the queue size is respected.
      # NOTE(review): this bounds the queue by new_cycle_length rather than
      # by make_iter_buffer_size, and does not adjust
      # _next_index_in_unbuffered_datasets for dropped futures — confirm
      # this is intended.
      while len(self._queued_iterators) > new_cycle_length:
        future = self._queued_iterators.pop()
        future.cancel()
      # TODO: Update _future_states for iterators that are moved out
      # of _iterators_in_use.

    if new_cycle_length > self._cycle_length:
      # Increase the size of the lists to accommodate the new cycle length.
      self._iterators_in_use.extend(
          [None] * (new_cycle_length - self._cycle_length)
      )
      self._iterators_in_use_indices.extend(
          [-1] * (new_cycle_length - self._cycle_length)
      )

    self._cycle_length = new_cycle_length

  def _set_iter_buffer_size(self, new_iter_buffer_size: int):
    """Changes the per-child element prefetch buffer size.

    Only iterators created after this call use the new size.

    Raises:
      ValueError: If `new_iter_buffer_size` is negative.
    """
    # Subsequent iterators will have the new buffer size.
    if new_iter_buffer_size < 0:
      raise ValueError("iter_buffer_size must be non-negative.")
    self._iter_buffer_size = new_iter_buffer_size

  def _set_make_iter_buffer_size(self, new_make_iter_buffer_size: int):
    """Changes how many child iterators are created ahead of time.

    Raises:
      ValueError: If `new_make_iter_buffer_size` is negative.
    """
    if new_make_iter_buffer_size < 0:
      raise ValueError("make_iter_buffer_size must be non-negative.")
    while len(self._queued_iterators) > new_make_iter_buffer_size:
      # Drop the most recently queued futures and return their datasets to
      # the unbuffered pool.
      future = self._queued_iterators.pop()
      future.cancel()
      self._next_index_in_unbuffered_datasets -= 1
    self._make_iter_buffer_size = new_make_iter_buffer_size

  def _set_num_make_iter_threads(self, new_num_make_iter_threads: int):
    """Changes the number of background iterator-creation threads.

    Raises:
      ValueError: If `new_num_make_iter_threads` is negative.
    """
    if new_num_make_iter_threads < 0:
      raise ValueError("num_make_iter_threads must be non-negative.")
    self._num_make_iter_threads = new_num_make_iter_threads
    if self._make_iter_buffer_size == 0:
      # Prefetching is disabled; no executor to (re)create.
      return
    old_executor = self._executor
    if self._num_make_iter_threads > 0:
      self._executor = futures.ThreadPoolExecutor(
          self._num_make_iter_threads, thread_name_prefix="grain-prefetch"
      )
    else:
      self._executor = None
    if old_executor is not None:
      # Allows the old executor to finish running the tasks it was already
      # assigned asynchronously.
      old_executor.shutdown(wait=False)

  def _create_iterator_synchronously(
      self, index: int
  ) -> dataset.DatasetIterator[T]:
    """Creates (and starts prefetch on) an iterator for `datasets[index]`."""
    return _add_prefetch_and_make_iterator(
        self._datasets[index],
        interleave_iterator=weakref.ref(self),
        start_prefetch=True,
    )

  def _create_iterator_asynchronously(
      self, index: int
  ) -> futures.Future[dataset.DatasetIterator[T]]:
    """Schedules iterator creation for `datasets[index]` on the executor.

    Raises:
      ValueError: If the executor has not been initialized.
    """
    if self._executor is None:
      raise ValueError("Executor has not been initialized.")
    return self._executor.submit(
        _add_prefetch_and_make_iterator,
        self._datasets[index],
        interleave_iterator=weakref.ref(self),
        start_prefetch=True,
    )

  def __next__(self) -> T:
    """Returns the next element, cycling round-robin over child iterators.

    Raises:
      StopIteration: When all datasets are exhausted.
    """
    self._assert_not_closed()
    self.start_prefetch()
    while True:
      self._fill_queued_iterators()
      # If the slot is empty, fill it with the next iterator.
      if self._iterators_in_use[self._next_index_in_cycle] is None:
        if self._next_index_in_datasets < len(self._datasets):
          if self._queued_iterators:
            # Case of asynchronous iterator creation.
            future = self._queued_iterators.popleft()
            self._fill_queued_iterators()
            self._iterators_in_use[self._next_index_in_cycle] = future.result()
          else:
            # Case of synchronous iterator creation.
            self._iterators_in_use[self._next_index_in_cycle] = (
                self._create_iterator_synchronously(
                    self._next_index_in_datasets
                )
            )
            self._next_index_in_unbuffered_datasets += 1
          self._iterators_in_use_indices[self._next_index_in_cycle] = (
              self._next_index_in_datasets
          )
          self._next_index_in_datasets += 1
      # If the slot is not empty, try to get the next element from the
      # iterator.
      if self._iterators_in_use[self._next_index_in_cycle] is not None:
        try:
          element = next(self._iterators_in_use[self._next_index_in_cycle])
          self._increment_next_index_in_cycle()
          return element
        except StopIteration:
          # This child is exhausted; free the slot for the next dataset.
          self._iterators_in_use[self._next_index_in_cycle] = None
      if not any(self._iterators_in_use):
        # All datasets have been exhausted.
        self._stop_prefetch()
        raise StopIteration
      self._increment_next_index_in_cycle()

  def get_state(self):
    """Returns a checkpointable state dict for this iterator."""
    if self._placeholder_state is None:
      # This placeholder state allows the state spec to remain consistent for
      # Pathways Remote Python.
      self._placeholder_state = _add_prefetch_and_make_iterator(
          self._datasets[0],
          interleave_iterator=weakref.ref(self),
          start_prefetch=False,
      ).get_state()
    iterators_in_use_states = [
        it.get_state() if it is not None else self._placeholder_state
        for it in self._iterators_in_use
    ]
    # 1/0 mask distinguishing real iterator states from placeholder states.
    iterator_present = [
        1 if it is not None else 0 for it in self._iterators_in_use
    ]
    return {
        "cycle_length": self._cycle_length,
        "iter_buffer_size": self._iter_buffer_size,
        "make_iter_buffer_size": self._make_iter_buffer_size,
        "num_make_iter_threads": self._num_make_iter_threads,
        "next_index_in_cycle": self._next_index_in_cycle,
        "next_index_in_datasets": self._next_index_in_datasets,
        "iterators_in_use_indices": self._iterators_in_use_indices,
        "future_states": self._future_states,
        "iterators_in_use_states": iterators_in_use_states,
        "iterator_present": iterator_present,
        "placeholder_state": self._placeholder_state,
    }

  def set_state(self, state):
    """Restores the iterator from a state dict produced by `get_state`.

    Args:
      state: State dict previously returned by `get_state`, possibly
        recorded with different tuning parameters (the iterator is resized
        to match before the per-child states are applied).
    """
    # Resize before setting state to avoid issues with mismatched list sizes
    # or mismatched thread count.
    if self._cycle_length != state["cycle_length"]:
      self._set_cycle_length(state["cycle_length"])
    if self._iter_buffer_size != state["iter_buffer_size"]:
      self._set_iter_buffer_size(state["iter_buffer_size"])
    if self._make_iter_buffer_size != state["make_iter_buffer_size"]:
      self._set_make_iter_buffer_size(state["make_iter_buffer_size"])
    if self._num_make_iter_threads != state["num_make_iter_threads"]:
      self._set_num_make_iter_threads(state["num_make_iter_threads"])

    for i, iterator_state in enumerate(state["iterators_in_use_states"]):
      if state["iterator_present"][i] == 1:
        if (
            self._iterators_in_use[i] is not None
            and state["iterators_in_use_indices"][i]
            == self._iterators_in_use_indices[i]
        ):
          # The iterator currently in use is the same as the one specified in
          # the state. We can set the state of the iterator without
          # recreating it.
          it = self._iterators_in_use[i]
          assert it is not None
          it.set_state(iterator_state)
        elif state["iterators_in_use_indices"][i] != -1:
          # The iterator currently in use is different from the one specified
          # in the state. We need to recreate the iterator.
          it = self._create_iterator_synchronously(
              state["iterators_in_use_indices"][i]
          )
          it.set_state(iterator_state)
          self._iterators_in_use[i] = it
      else:
        # There is no iterator specified in the state for this position.
        self._iterators_in_use[i] = None

    if self._next_index_in_datasets != state["next_index_in_datasets"]:
      # The queued iterators are no longer valid. We need to cancel them.
      while self._queued_iterators:
        future = self._queued_iterators.popleft()
        future.cancel()
      self._next_index_in_unbuffered_datasets = state["next_index_in_datasets"]

    self._next_index_in_cycle = state["next_index_in_cycle"]
    self._next_index_in_datasets = state["next_index_in_datasets"]
    self._iterators_in_use_indices = state["iterators_in_use_indices"]
    self._future_states = state["future_states"]
    self._placeholder_state = state["placeholder_state"]