Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/advanced/packing.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ and {py:meth}`pack_selected_samples <megatron.energon.TaskEncoder.pack_selected_
Furthermore, you need to initialize the loader with the `packing_buffer_size` argument set to a non-zero number.

The `select_samples_to_pack` method will receive a list of samples (size according to the selected `packing_buffer_size`),
and should partition those samples into groups that shall be packed together. Hence the function returns
a list of lists of samples.
and should partition those samples into groups that shall be packed together. Hence the function typically returns
a list of lists of samples. Alternatively it may return {py:class}`PackedSamplesOutput <megatron.energon.PackedSamplesOutput>` with a ``pushback`` sequence: those samples are appended back to the reading buffer before the next fill from the dataset.

For each group, the second method `pack_selected_samples` will be called. You need to implement how a group of
samples will be mapped to a single sample. In terms of LLMs for example, this method might concatenate the input tokens.
Expand Down
4 changes: 2 additions & 2 deletions docs/source/basic/task_encoder.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ and override one or more of the following methods. The data flow of {py:func}`ge
- {py:meth}`def preencode_sample(self, sample: T_sample) -> T_sample <megatron.energon.TaskEncoder.preencode_sample>`
- Can be used as a replacement for {py:func}`encode_sample <megatron.energon.DefaultTaskEncoder.encode_sample>`
- Only used together with {py:meth}`postencode_sample <megatron.energon.TaskEncoder.postencode_sample>` (below). Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing.
- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> List[List[T_encoded_sample]] <megatron.energon.TaskEncoder.select_samples_to_pack>`
- Optional. Allows for efficient sample packing. See [](../advanced/packing).
- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> Union[List[List[T_encoded_sample]], PackedSamplesOutput] <megatron.energon.TaskEncoder.select_samples_to_pack>`
- Optional. Allows for efficient sample packing. See [](../advanced/packing). May return {py:class}`PackedSamplesOutput <megatron.energon.PackedSamplesOutput>` to re-queue a pushback sequence onto the reading buffer.
- {py:meth}`def postencode_sample(self, sample: T_sample) -> T_encoded_sample <megatron.energon.TaskEncoder.postencode_sample>`
- Only used together with {py:meth}`preencode_sample <megatron.energon.TaskEncoder.preencode_sample>`. Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing.
- {py:meth}`def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_batch_sample <megatron.energon.TaskEncoder.pack_selected_samples>`
Expand Down
2 changes: 2 additions & 0 deletions src/megatron/energon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
LogSampleDataset,
MapDataset,
MixBatchDataset,
PackedSamplesOutput,
PackingDataset,
RepeatDataset,
ShuffleBufferDataset,
Expand Down Expand Up @@ -186,6 +187,7 @@
"PrimaryFileStore",
"prepare_metadataset",
"RepeatDataset",
"PackedSamplesOutput",
"reraise_exception",
"Sample",
"SampleDecoder",
Expand Down
2 changes: 2 additions & 0 deletions src/megatron/energon/metadataset/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
return LoadedDatasetList(
Expand All @@ -106,6 +107,7 @@ def get_datasets(
**kwargs,
),
weight=None,
group=group,
)
],
)
2 changes: 2 additions & 0 deletions src/megatron/energon/metadataset/join_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
return LoadedDatasetList(
Expand All @@ -557,6 +558,7 @@ def get_datasets(
**kwargs,
),
weight=None,
group=group,
)
],
)
28 changes: 28 additions & 0 deletions src/megatron/energon/metadataset/loader_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,17 @@ class DatasetBlendMode(Enum):

@edataclass
class LoadedDataset:
    """One loaded dataset entry: the dataset factory plus its blending, grouping and aux settings.

    Every field except ``dataset`` is optional and defaults to ``None``.
    """

    #: The dataset factory.
    dataset: BaseCoreDatasetFactory
    #: Sampling weight when using dataset-weight blending.
    weight: Union[float, int, None] = None
    #: Epochized repetition count when using repetition-based blending.
    repetitions: Union[float, int, None] = None
    #: Dataset group key from Metadataset V2 (YAML ``group`` field). Must match keys used in
    #: ``packing_buffer_size`` / ``shuffle_buffer_size`` when those are dicts. ``None`` is the
    #: default group.
    group: Optional[str] = None
    #: Auxiliary datasets for crude cooking.
    aux: Optional[Dict[str, FileStore]] = None


Expand All @@ -48,12 +56,18 @@ class TraversedDatasetReference:
split_part: Effective split part to use when loading the leaf dataset.
aux: Resolved auxiliary dataset or filesystem references keyed by auxiliary name.
subflavors: Effective subflavors implied by the traversed metadataset hierarchy.
group: Merged dataset group name from Metadataset V2 ``group`` fields along the path to this
leaf (nested segments join with ``+``). Keys dict entries for per-group
``packing_buffer_size`` and ``shuffle_buffer_size``.
shuffle_over_epochs_multiplier: Effective shuffle over epochs multiplier from metadataset references.
"""

path: EPath
split_part: str
aux: dict[str, EPath]
subflavors: dict[str, Any]
group: Optional[str] = None
shuffle_over_epochs_multiplier: Optional[int] = 1


class DatasetLoaderInterface(ABC):
Expand All @@ -69,6 +83,8 @@ def traverse(
mds_path: Optional[EPath] = None,
*,
split_part: Union[Literal["train", "val", "test"], str],
_group: Optional[str] = None,
_shuffle_over_epochs_multiplier: Optional[int] = 1,
_subflavors: Optional[Dict[str, Any]] = None,
) -> List[TraversedDatasetReference]:
"""Traverse a metadataset subtree and collect flattened leaf dataset references.
Expand All @@ -83,6 +99,12 @@ def traverse(
use None only for top-level metadatasets.
split_part: Split to traverse, such as `\"train\"`, `\"val\"`, or `\"test\"`. Nested
references may override this with their own configured split.
_group: Inherited merged group name from Metadataset V2 ``group`` fields (nested segments
join with ``+``). Used to match dict keys in ``packing_buffer_size`` /
``shuffle_buffer_size``.
_shuffle_over_epochs_multiplier: Inherited shuffle multiplier (merged per node like
``get_datasets``); default ``1``.
_subflavors: Effective subflavors implied by the traversed metadataset hierarchy.

Returns:
A flattened list of `TraversedDatasetReference` values for all leaf datasets reached
Expand All @@ -100,6 +122,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
"""
Expand All @@ -120,6 +143,11 @@ def get_datasets(
an infinite number of epochs (effectively, this will draw shard slices with
replacement).
subset: If specified, the inner dataset(s) will be subsetted.
group: Dataset group for this leaf (merged with outer scopes). Datasets that share the same
non-``None`` key are blended and shuffled together, then each group's stream runs through
packing (:class:`~megatron.energon.PackingDataset`) before streams from different groups
are blended. Must match dict keys in ``packing_buffer_size`` / ``shuffle_buffer_size`` when
those are mappings. ``None`` selects the default group.
**kwargs: Additional arguments to the dataset constructor.

Returns:
Expand Down
68 changes: 42 additions & 26 deletions src/megatron/energon/metadataset/metadataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ def _merge_traversed_subflavors(
return {**self.subflavors, **(inherited_subflavors or {})}
return dict(inherited_subflavors or {})

def _merge_shuffle_over_epochs_multiplier(
    self, inherited_shuffle_over_epochs_multiplier: Optional[int]
) -> Optional[int]:
    """Combine the inherited shuffle-over-epochs multiplier with this node's own value.

    Same semantics as Metadataset V2 ``ShuffleOverEpochsMultiplierMixin`` / ``get_datasets``.

    Args:
        inherited_shuffle_over_epochs_multiplier: Multiplier handed down by the caller.

    Returns:
        ``None`` if either multiplier is ``None``; otherwise ``-1`` if either multiplier
        is ``-1``; otherwise the product of the two multipliers.
    """
    inherited = inherited_shuffle_over_epochs_multiplier
    # ``None`` on either side takes precedence over everything else, including ``-1``.
    if inherited is None:
        return None
    own = self.shuffle_over_epochs_multiplier
    if own is None:
        return None
    # ``-1`` on either side is propagated unchanged instead of being multiplied.
    if inherited == -1 or own == -1:
        return -1
    return inherited * own

def post_initialize(self, mds_path: Optional[EPath] = None):
self._resolve_path(mds_path)
if self.path.is_file():
Expand All @@ -101,35 +117,30 @@ def traverse(
mds_path: Optional[EPath] = None,
*,
split_part: Union[Literal["train", "val", "test"], str],
_group: Optional[str] = None,
_shuffle_over_epochs_multiplier: Optional[int] = 1,
_subflavors: Optional[Dict[str, Any]] = None,
) -> List[TraversedDatasetReference]:
"""Traverse this V1 dataset reference into flattened leaf references.

Args:
mds_path: Parent metadataset path used internally to resolve relative dataset and
auxiliary paths. Must be set for nested references and inner traversal nodes;
use None only for top-level metadatasets.
split_part: Split inherited from the parent traversal. If this reference defines its
own split override, that split takes precedence for nested traversal and the
returned leaf reference.

Returns:
A single leaf `TraversedDatasetReference` for direct dataset references, or the
flattened traversal result of the nested metadataset when this reference points to one.
"""
self._resolve_path(mds_path)
effective_subflavors = self._merge_traversed_subflavors(_subflavors)
_subflavors = self._merge_traversed_subflavors(_subflavors)
_shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier(
_shuffle_over_epochs_multiplier
)
if self.path.is_file():
return self._load_nested_metadataset().traverse(
split_part=self.split_part or split_part,
_subflavors=effective_subflavors,
_group=_group,
_shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier,
_subflavors=_subflavors,
)
return [
TraversedDatasetReference(
path=self.path,
split_part=self.split_part or split_part,
aux={},
subflavors=effective_subflavors,
subflavors=_subflavors,
group=_group,
shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier,
)
]

Expand All @@ -142,6 +153,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
if self.subflavors is not None:
Expand All @@ -167,6 +179,7 @@ def get_datasets(
subflavors=subflavors,
shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier,
subset=subset,
group=group,
**kwargs,
)

Expand All @@ -187,6 +200,8 @@ def traverse(
mds_path: Optional[EPath] = None,
*,
split_part: Union[Literal["train", "val", "test"], str],
_group: Optional[str] = None,
_shuffle_over_epochs_multiplier: Optional[int] = 1,
_subflavors: Optional[Dict[str, Any]] = None,
) -> List[TraversedDatasetReference]:
assert mds_path is not None
Expand All @@ -196,6 +211,8 @@ def traverse(
dataset.traverse(
mds_path,
split_part=split_part,
_group=_group,
_shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier,
_subflavors=_subflavors,
)
)
Expand All @@ -210,6 +227,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
sum_weight = sum(dataset.weight for dataset in self.datasets)
Expand All @@ -222,6 +240,7 @@ def get_datasets(
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
group=group,
**kwargs,
)
if inner_result.blend_mode not in (
Expand Down Expand Up @@ -270,21 +289,16 @@ def traverse(
mds_path: Optional[EPath] = None,
*,
split_part: Union[Literal["train", "val", "test"], str],
_group: Optional[str] = None,
_shuffle_over_epochs_multiplier: Optional[int] = 1,
_subflavors: Optional[Dict[str, Any]] = None,
) -> List[TraversedDatasetReference]:
"""Traverse the selected V1 split and flatten all reachable leaf references.

Args:
mds_path: Unused for top-level metadatasets. Present to satisfy the shared interface.
split_part: Split to traverse.

Returns:
The flattened list of traversed leaf dataset references for `split_part`.
"""
assert mds_path is None
return self._splits[split_part].traverse(
self._path,
split_part=split_part,
_group=_group,
_shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier,
_subflavors=_subflavors,
)

Expand All @@ -297,6 +311,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
return self._splits[split_part].get_datasets(
Expand All @@ -306,5 +321,6 @@ def get_datasets(
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
group=group,
**kwargs,
)
Loading
Loading