Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 306 additions & 0 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
"""Implements dataset interleaving."""

import collections
from collections.abc import Sequence
import functools
from typing import Any, TypeVar
import weakref

from concurrent import futures
from grain._src.python import options as grain_options
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
Expand Down Expand Up @@ -436,3 +438,307 @@ def __str__(self) -> str:
def _element_spec(self) -> Any:
# Assumes that interleaved datasets have the same element spec.
return dataset.get_element_spec(self._datasets[0])


class TunableInterleaveDatasetIterator(dataset.DatasetIterator[T]):
  """Iterator for InterleaveIterDataset.

  Interleaves elements from `cycle_length` of the given `datasets` at a time,
  visiting the active iterators round-robin. Creation of upcoming iterators
  (including starting their prefetch) can be offloaded to a thread pool so
  that up to `make_iter_buffer_size` iterators are ready before they are
  needed. The tuning knobs (`cycle_length`, `iter_buffer_size`,
  `make_iter_buffer_size`, `num_make_iter_threads`) are part of the
  checkpoint state and are re-applied on `set_state`, hence "tunable".
  """

  def __init__(
      self,
      datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
      cycle_length: int,
      num_make_iter_threads: int = 1,
      make_iter_buffer_size: int = 1,
      iter_buffer_size: int = 1,
  ):
    """Initializes the interleave iterator.

    Args:
      datasets: Datasets to interleave elements from.
      cycle_length: Number of datasets iterated over concurrently; elements
        are taken from them round-robin.
      num_make_iter_threads: Number of threads used to create iterators in
        the background. If this or `make_iter_buffer_size` is 0, iterators
        are created synchronously on first use.
      make_iter_buffer_size: Maximum number of upcoming iterators to create
        ahead of time in the background.
      iter_buffer_size: Per-iterator prefetch buffer size, applied to each
        interleaved dataset (consumed by `_add_prefetch_and_make_iterator`).
    """
    super().__init__()
    self._datasets = datasets
    self._cycle_length = cycle_length
    self._num_make_iter_threads = num_make_iter_threads
    self._make_iter_buffer_size = make_iter_buffer_size
    self._iter_buffer_size = iter_buffer_size
    # Whether background iterator creation has been started.
    self._started = False

    # cycle_length iterators that are in use; None marks an empty slot.
    self._iterators_in_use: list[dataset.DatasetIterator[T] | None] = [
        None
    ] * self._cycle_length
    # Indices of datasets being iterated over in _iterators_in_use; -1 marks
    # a slot that has not been filled yet.
    self._iterators_in_use_indices: list[int] = [
        -1 for _ in range(self._cycle_length)
    ]
    # Up to make_iter_buffer_size iterators prefetched in the background,
    # held as futures in dataset order.
    self._queued_iterators: collections.deque[
        futures.Future[dataset.DatasetIterator[T]]
    ] = collections.deque()
    # Executor for creating iterators and starting prefetch asynchronously;
    # None until start_prefetch (or when background creation is disabled).
    self._executor: futures.ThreadPoolExecutor | None = None
    # Index of the next iterator to be used in _iterators_in_use.
    self._next_index_in_cycle: int = 0
    # Index into _datasets of the next iterator to be put into
    # _iterators_in_use.
    self._next_index_in_datasets: int = 0
    # Index into _datasets of the next iterator that has NOT yet been queued
    # for background creation (>= _next_index_in_datasets).
    self._next_index_in_unbuffered_datasets: int = 0
    # Future states of the iterators that were being used before cycle_length
    # was reduced. NOTE(review): currently never written to (see TODO in
    # _set_cycle_length); only carried through get_state/set_state.
    self._future_states: dict[str, dict[str, Any]] = {}
    # Placeholder state for use when datasets are exhausted. This is used for
    # Pathways Remote Python, which requires state spec to remain consistent.
    # This works when the input datasets have the same state spec.
    self._placeholder_state: dict[str, Any] | None = None

  def _increment_next_index_in_cycle(self) -> None:
    # Advance the round-robin cursor, wrapping at cycle_length.
    self._next_index_in_cycle = (
        self._next_index_in_cycle + 1
    ) % self._cycle_length

  def _fill_queued_iterators(self) -> None:
    """Tops up the background-creation queue to make_iter_buffer_size."""
    if self._make_iter_buffer_size == 0 or self._num_make_iter_threads == 0:
      # Iterator prefetching is not possible when threads or buffer size are 0.
      return
    while len(self._queued_iterators) < self._make_iter_buffer_size:
      if self._next_index_in_unbuffered_datasets >= len(self._datasets):
        break
      future = self._create_iterator_asynchronously(
          self._next_index_in_unbuffered_datasets
      )
      self._queued_iterators.append(future)
      self._next_index_in_unbuffered_datasets += 1

  def start_prefetch(self) -> None:
    """Starts background iterator creation; idempotent."""
    if self._started:
      return
    if self._num_make_iter_threads > 0 and self._make_iter_buffer_size > 0:
      assert self._executor is None
      self._executor = futures.ThreadPoolExecutor(
          self._num_make_iter_threads, thread_name_prefix="grain-interleave"
      )
      self._fill_queued_iterators()
    self._started = True

  def _stop_prefetch(self) -> None:
    """Cancels queued creations and shuts the executor down; idempotent."""
    if not self._started:
      return
    self._started = False
    while self._queued_iterators:
      future = self._queued_iterators.popleft()
      # NOTE(review): Future.cancel is best-effort — a task already running
      # will still complete and its iterator is then dropped unclosed.
      # Confirm this is acceptable here.
      future.cancel()
    # Rewind so that a later restart re-queues the cancelled datasets.
    self._next_index_in_unbuffered_datasets = self._next_index_in_datasets
    if self._executor is not None:
      self._executor.shutdown(wait=True)
      self._executor = None

  def close(self) -> None:
    """Closes all in-use and queued iterators and stops prefetching."""
    for iterator in self._iterators_in_use:
      if iterator is not None:
        iterator.close()
    for future in self._queued_iterators:
      # NOTE(review): result() blocks until creation finishes and raises
      # CancelledError for cancelled futures — confirm queued futures cannot
      # be cancelled at this point.
      iterator = future.result()
      iterator.close()
    self._stop_prefetch()

  def _set_cycle_length(self, new_cycle_length: int) -> None:
    """Resizes the cycle, clamping to len(datasets).

    NOTE(review): on shrink, iterators moved back into _queued_iterators are
    not removed from their _iterators_in_use slots and the slot lists are not
    truncated, so get_state/close still see them; also the queue-size bound
    below compares against new_cycle_length rather than
    _make_iter_buffer_size — confirm both are intended.
    """
    if new_cycle_length < 0:
      raise ValueError("cycle_length must be non-negative.")
    if new_cycle_length > len(self._datasets):
      new_cycle_length = len(self._datasets)
    if new_cycle_length == self._cycle_length:
      return

    if new_cycle_length < self._cycle_length:
      # Move the iterators that are no longer in use to the queued iterators.
      for i in reversed(range(new_cycle_length, self._cycle_length)):
        it = self._iterators_in_use[i]
        if it is not None:
          # Wrap the live iterator in an already-resolved future so it can
          # share the queue with pending creations.
          future = futures.Future()
          future.set_result(it)
          self._queued_iterators.appendleft(future)
      # Ensure the queue size is respected.
      while len(self._queued_iterators) > new_cycle_length:
        future = self._queued_iterators.pop()
        future.cancel()
      # TODO: Update _future_states for iterators that are moved out
      # of _iterators_in_use.

    if new_cycle_length > self._cycle_length:
      # Increase the size of the lists to accommodate the new cycle length.
      self._iterators_in_use.extend(
          [None] * (new_cycle_length - self._cycle_length)
      )
      self._iterators_in_use_indices.extend(
          [-1] * (new_cycle_length - self._cycle_length)
      )

    self._cycle_length = new_cycle_length

  def _set_iter_buffer_size(self, new_iter_buffer_size: int) -> None:
    """Sets the per-iterator prefetch buffer size for future iterators."""
    # Subsequent iterators will have the new buffer size.
    if new_iter_buffer_size < 0:
      raise ValueError("iter_buffer_size must be non-negative.")
    self._iter_buffer_size = new_iter_buffer_size

  def _set_make_iter_buffer_size(self, new_make_iter_buffer_size: int) -> None:
    """Sets the background-creation queue size, trimming excess futures."""
    if new_make_iter_buffer_size < 0:
      raise ValueError("make_iter_buffer_size must be non-negative.")
    while len(self._queued_iterators) > new_make_iter_buffer_size:
      future = self._queued_iterators.pop()
      # NOTE(review): if the task has already started, cancel() fails but the
      # index is still decremented, so the dataset would be iterated twice —
      # confirm cancellation is guaranteed to succeed here.
      future.cancel()
      self._next_index_in_unbuffered_datasets -= 1
    self._make_iter_buffer_size = new_make_iter_buffer_size

  def _set_num_make_iter_threads(self, new_num_make_iter_threads: int) -> None:
    """Replaces the creation executor with one of the new thread count."""
    if new_num_make_iter_threads < 0:
      raise ValueError("num_make_iter_threads must be non-negative.")
    self._num_make_iter_threads = new_num_make_iter_threads
    if self._make_iter_buffer_size == 0:
      # Prefetching is disabled; no executor to manage.
      return
    old_executor = self._executor
    if self._num_make_iter_threads > 0:
      # NOTE(review): thread_name_prefix differs from the "grain-interleave"
      # prefix used in start_prefetch — confirm which is intended.
      self._executor = futures.ThreadPoolExecutor(
          self._num_make_iter_threads, thread_name_prefix="grain-prefetch"
      )
    else:
      self._executor = None
    if old_executor is not None:
      # Allows the old executor to finish running the tasks it was already
      # assigned asynchronously.
      old_executor.shutdown(wait=False)

  def _create_iterator_synchronously(
      self, index: int
  ) -> dataset.DatasetIterator[T]:
    """Creates (and starts prefetch of) the iterator for datasets[index]."""
    return _add_prefetch_and_make_iterator(
        self._datasets[index],
        # Weak ref so the child iterator does not keep this iterator alive.
        interleave_iterator=weakref.ref(self),
        start_prefetch=True,
    )

  def _create_iterator_asynchronously(
      self, index: int
  ) -> futures.Future[dataset.DatasetIterator[T]]:
    """Submits creation of the iterator for datasets[index] to the executor.

    Raises:
      ValueError: If the executor has not been initialized.
    """
    if self._executor is None:
      raise ValueError("Executor has not been initialized.")
    return self._executor.submit(
        _add_prefetch_and_make_iterator,
        self._datasets[index],
        interleave_iterator=weakref.ref(self),
        start_prefetch=True,
    )

  def __next__(self) -> T:
    """Returns the next element, cycling round-robin over the iterators.

    Empty cycle slots are filled from the background queue (or created
    synchronously when prefetching is disabled). Exhausted iterators free
    their slot for the next dataset.

    Raises:
      StopIteration: When all datasets have been exhausted.
    """
    self._assert_not_closed()
    self.start_prefetch()
    while True:
      self._fill_queued_iterators()
      # If the slot is empty, fill it with the next iterator.
      if self._iterators_in_use[self._next_index_in_cycle] is None:
        if self._next_index_in_datasets < len(self._datasets):
          if self._queued_iterators:
            # Case of asynchronous iterator creation.
            future = self._queued_iterators.popleft()
            self._fill_queued_iterators()
            self._iterators_in_use[self._next_index_in_cycle] = future.result()
          else:
            # Case of synchronous iterator creation.
            self._iterators_in_use[self._next_index_in_cycle] = (
                self._create_iterator_synchronously(
                    self._next_index_in_datasets
                )
            )
            # Keep the unbuffered pointer in sync when nothing was queued.
            self._next_index_in_unbuffered_datasets += 1
          self._iterators_in_use_indices[self._next_index_in_cycle] = (
              self._next_index_in_datasets
          )
          self._next_index_in_datasets += 1
      # If the slot is not empty, try to get the next element from the iterator.
      if self._iterators_in_use[self._next_index_in_cycle] is not None:
        try:
          element = next(self._iterators_in_use[self._next_index_in_cycle])
          self._increment_next_index_in_cycle()
          return element
        except StopIteration:
          # This dataset is exhausted; free the slot for the next dataset.
          self._iterators_in_use[self._next_index_in_cycle] = None
      if not any(self._iterators_in_use):
        # All datasets have been exhausted.
        self._stop_prefetch()
        raise StopIteration
      self._increment_next_index_in_cycle()

  def get_state(self):
    """Returns a checkpointable state dict.

    Empty cycle slots are represented by a placeholder state (plus a 0 in
    "iterator_present") so that the state spec stays constant across steps,
    as required by Pathways Remote Python.
    """
    if self._placeholder_state is None:
      # This placeholder state allows state spec to remain consistent for
      # Pathways Remote Python.
      self._placeholder_state = _add_prefetch_and_make_iterator(
          self._datasets[0],
          interleave_iterator=weakref.ref(self),
          start_prefetch=False,
      ).get_state()
    iterators_in_use_states = [
        it.get_state() if it is not None else self._placeholder_state
        for it in self._iterators_in_use
    ]
    # 1/0 mask distinguishing real states from placeholder states.
    iterator_present = [
        1 if it is not None else 0 for it in self._iterators_in_use
    ]
    return {
        "cycle_length": self._cycle_length,
        "iter_buffer_size": self._iter_buffer_size,
        "make_iter_buffer_size": self._make_iter_buffer_size,
        "num_make_iter_threads": self._num_make_iter_threads,
        "next_index_in_cycle": self._next_index_in_cycle,
        "next_index_in_datasets": self._next_index_in_datasets,
        "iterators_in_use_indices": self._iterators_in_use_indices,
        "future_states": self._future_states,
        "iterators_in_use_states": iterators_in_use_states,
        "iterator_present": iterator_present,
        "placeholder_state": self._placeholder_state,
    }

  def set_state(self, state) -> None:
    """Restores from a state dict produced by get_state.

    Re-applies the tuning knobs first (which may resize internal lists or
    replace the executor), then restores or recreates each in-use iterator,
    reusing a live iterator when it maps to the same dataset index.
    """
    # Resize before setting state to avoid issues with mismatched list sizes
    # or mismatched thread count.
    if self._cycle_length != state["cycle_length"]:
      self._set_cycle_length(state["cycle_length"])
    if self._iter_buffer_size != state["iter_buffer_size"]:
      self._set_iter_buffer_size(state["iter_buffer_size"])
    if self._make_iter_buffer_size != state["make_iter_buffer_size"]:
      self._set_make_iter_buffer_size(state["make_iter_buffer_size"])
    if self._num_make_iter_threads != state["num_make_iter_threads"]:
      self._set_num_make_iter_threads(state["num_make_iter_threads"])

    for i, iterator_state in enumerate(state["iterators_in_use_states"]):
      if state["iterator_present"][i] == 1:
        if (
            self._iterators_in_use[i] is not None
            and state["iterators_in_use_indices"][i]
            == self._iterators_in_use_indices[i]
        ):
          # The iterator currently in use is the same one specified in the
          # state. We can set the state of the iterator without recreating it.
          it = self._iterators_in_use[i]
          assert it is not None
          it.set_state(iterator_state)
        elif state["iterators_in_use_indices"][i] != -1:
          # The iterator currently in use is different from the one specified
          # in the state. We need to recreate the iterator.
          it = self._create_iterator_synchronously(
              state["iterators_in_use_indices"][i]
          )
          it.set_state(iterator_state)
          self._iterators_in_use[i] = it
      else:
        # There is no iterator specified in the state for this position.
        self._iterators_in_use[i] = None

    if self._next_index_in_datasets != state["next_index_in_datasets"]:
      # The queued iterators are no longer valid. We need to cancel them.
      while self._queued_iterators:
        future = self._queued_iterators.popleft()
        future.cancel()
      self._next_index_in_unbuffered_datasets = state["next_index_in_datasets"]

    self._next_index_in_cycle = state["next_index_in_cycle"]
    self._next_index_in_datasets = state["next_index_in_datasets"]
    self._iterators_in_use_indices = state["iterators_in_use_indices"]
    self._future_states = state["future_states"]
    self._placeholder_state = state["placeholder_state"]
Loading