diff --git a/docs/source/advanced/packing.md b/docs/source/advanced/packing.md index 45df53ac..0a62a48a 100644 --- a/docs/source/advanced/packing.md +++ b/docs/source/advanced/packing.md @@ -18,8 +18,8 @@ and {py:meth}`pack_selected_samples ` with a ``pushback`` sequence: those samples are appended back to the reading buffer before the next fill from the dataset. For each group, the second method `pack_selected_samples` will be called. You need to implement how a group of samples will be mapped to a single sample. In terms of LLMs for example, this method might concatenate the input tokens. diff --git a/docs/source/basic/task_encoder.md b/docs/source/basic/task_encoder.md index 37295b0b..076a2709 100644 --- a/docs/source/basic/task_encoder.md +++ b/docs/source/basic/task_encoder.md @@ -19,8 +19,8 @@ and override one or more of the following methods. The data flow of {py:func}`ge - {py:meth}`def preencode_sample(self, sample: T_sample) -> T_sample ` - Can be used as a replacement to {py:func}`encode_sample ` - Only used together with {py:meth}`postencode_sample ` (below). Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing. -- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> List[List[T_encoded_sample]] ` - - Optional. Allows for efficient sample packing. See [](../advanced/packing). +- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> Union[List[List[T_encoded_sample]], PackedSamplesOutput] ` + - Optional. Allows for efficient sample packing. See [](../advanced/packing). May return {py:class}`PackedSamplesOutput ` to re-queue a pushback sequence onto the reading buffer. - {py:meth}`def postencode_sample(self, sample: T_sample) -> T_encoded_sample ` - Only used together with {py:meth}`preencode_sample `. Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing. - {py:meth}`def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_batch_sample] ` diff --git a/src/megatron/energon/__init__.py b/src/megatron/energon/__init__.py index e5204230..4095a735 100644 --- a/src/megatron/energon/__init__.py +++ b/src/megatron/energon/__init__.py @@ -104,6 +104,7 @@ LogSampleDataset, MapDataset, MixBatchDataset, + PackedSamplesOutput, PackingDataset, RepeatDataset, ShuffleBufferDataset, @@ -186,6 +187,7 @@ "PrimaryFileStore", "prepare_metadataset", "RepeatDataset", + "PackedSamplesOutput", "reraise_exception", "Sample", "SampleDecoder", diff --git a/src/megatron/energon/metadataset/dataset_loader.py b/src/megatron/energon/metadataset/dataset_loader.py index 46e93588..b115ac9c 100644 --- a/src/megatron/energon/metadataset/dataset_loader.py +++ b/src/megatron/energon/metadataset/dataset_loader.py @@ -90,6 +90,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return LoadedDatasetList( @@ -106,6 +107,7 @@ def get_datasets( **kwargs, ), weight=None, + group=group, ) ], ) diff --git a/src/megatron/energon/metadataset/join_dataset_loader.py b/src/megatron/energon/metadataset/join_dataset_loader.py index 4ea1cbd4..7e00791b 100644 --- a/src/megatron/energon/metadataset/join_dataset_loader.py +++ b/src/megatron/energon/metadataset/join_dataset_loader.py @@ -541,6 +541,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return LoadedDatasetList( @@ -557,6 +558,7 @@ def get_datasets( **kwargs, ), weight=None, + group=group, ) ], ) diff --git a/src/megatron/energon/metadataset/loader_interface.py b/src/megatron/energon/metadataset/loader_interface.py index ba7c6ea5..28779e26 100644 --- a/src/megatron/energon/metadataset/loader_interface.py +++ b/src/megatron/energon/metadataset/loader_interface.py @@ -27,9 +27,17 @@ class DatasetBlendMode(Enum): @edataclass class LoadedDataset: + #: The dataset factory. dataset: BaseCoreDatasetFactory + #: Sampling weight when using dataset-weight blending. weight: Union[float, int, None] = None + #: Epochized repetition count when using repetition-based blending. repetitions: Union[float, int, None] = None + #: Dataset group key from Metadataset V2 (YAML ``group`` field). Must match keys used in + #: ``packing_buffer_size`` / ``shuffle_buffer_size`` when those are dicts. ``None`` is the + #: default group. + group: Optional[str] = None + #: Auxiliary datasets for crude cooking. aux: Optional[Dict[str, FileStore]] = None @@ -48,12 +56,18 @@ class TraversedDatasetReference: split_part: Effective split part to use when loading the leaf dataset. aux: Resolved auxiliary dataset or filesystem references keyed by auxiliary name. subflavors: Effective subflavors implied by the traversed metadataset hierarchy. + group: Merged dataset group name from Metadataset V2 ``group`` fields along the path to this + leaf (nested segments join with ``+``). Keys dict entries for per-group + ``packing_buffer_size`` and ``shuffle_buffer_size``. + shuffle_over_epochs_multiplier: Effective shuffle over epochs multiplier from metadataset references. """ path: EPath split_part: str aux: dict[str, EPath] subflavors: dict[str, Any] + group: Optional[str] = None + shuffle_over_epochs_multiplier: Optional[int] = 1 class DatasetLoaderInterface(ABC): @@ -69,6 +83,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: """Traverse a metadataset subtree and collect flattened leaf dataset references. @@ -83,6 +99,12 @@ def traverse( use None only for top-level metadatasets. split_part: Split to traverse, such as `\"train\"`, `\"val\"`, or `\"test\"`. Nested references may override this with their own configured split. + _group: Inherited merged group name from Metadataset V2 ``group`` fields (nested segments + join with ``+``). Used to match dict keys in ``packing_buffer_size`` / + ``shuffle_buffer_size``. + _shuffle_over_epochs_multiplier: Inherited shuffle multiplier (merged per node like + ``get_datasets``); default ``1``. + _subflavors: Effective subflavors implied by the traversed metadataset hierarchy. Returns: A flattened list of `TraversedDatasetReference` values for all leaf datasets reached @@ -100,6 +122,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: """ @@ -120,6 +143,11 @@ def get_datasets( an infinite number of epochs (effectively, this will draw shard slices with replacement). subset: If specified, the inner dataset(s) will be subsetted. + group: Dataset group for this leaf (merged with outer scopes). Datasets that share the same + non-``None`` key are blended and shuffled together, then each group's stream runs through + packing (:class:`~megatron.energon.PackingDataset`) before streams from different groups + are blended. Must match dict keys in ``packing_buffer_size`` / ``shuffle_buffer_size`` when + those are mappings. ``None`` selects the default group. **kwargs: Additional arguments to the dataset constructor. Returns: diff --git a/src/megatron/energon/metadataset/metadataset.py b/src/megatron/energon/metadataset/metadataset.py index ec8efa45..8ce475aa 100644 --- a/src/megatron/energon/metadataset/metadataset.py +++ b/src/megatron/energon/metadataset/metadataset.py @@ -81,6 +81,22 @@ def _merge_traversed_subflavors( return {**self.subflavors, **(inherited_subflavors or {})} return dict(inherited_subflavors or {}) + def _merge_shuffle_over_epochs_multiplier( + self, inherited_shuffle_over_epochs_multiplier: Optional[int] + ) -> Optional[int]: + """Same semantics as Metadataset V2 ``ShuffleOverEpochsMultiplierMixin`` / ``get_datasets``.""" + if ( + inherited_shuffle_over_epochs_multiplier is None + or self.shuffle_over_epochs_multiplier is None + ): + return None + if ( + inherited_shuffle_over_epochs_multiplier == -1 + or self.shuffle_over_epochs_multiplier == -1 + ): + return -1 + return inherited_shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier + def post_initialize(self, mds_path: Optional[EPath] = None): self._resolve_path(mds_path) if self.path.is_file(): @@ -101,35 +117,30 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: - """Traverse this V1 dataset reference into flattened leaf references. - - Args: - mds_path: Parent metadataset path used internally to resolve relative dataset and - auxiliary paths. Must be set for nested references and inner traversal nodes; - use None only for top-level metadatasets. - split_part: Split inherited from the parent traversal. If this reference defines its - own split override, that split takes precedence for nested traversal and the - returned leaf reference. - - Returns: - A single leaf `TraversedDatasetReference` for direct dataset references, or the - flattened traversal result of the nested metadataset when this reference points to one. - """ self._resolve_path(mds_path) - effective_subflavors = self._merge_traversed_subflavors(_subflavors) + _subflavors = self._merge_traversed_subflavors(_subflavors) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) if self.path.is_file(): return self._load_nested_metadataset().traverse( split_part=self.split_part or split_part, - _subflavors=effective_subflavors, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, + _subflavors=_subflavors, ) return [ TraversedDatasetReference( path=self.path, split_part=self.split_part or split_part, aux={}, - subflavors=effective_subflavors, + subflavors=_subflavors, + group=_group, + shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, ) ] @@ -142,6 +153,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: if self.subflavors is not None: @@ -167,6 +179,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) @@ -187,6 +200,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: assert mds_path is not None @@ -196,6 +211,8 @@ def traverse( dataset.traverse( mds_path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) ) @@ -210,6 +227,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: sum_weight = sum(dataset.weight for dataset in self.datasets) @@ -222,6 +240,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) if inner_result.blend_mode not in ( @@ -270,21 +289,16 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: - """Traverse the selected V1 split and flatten all reachable leaf references. - - Args: - mds_path: Unused for top-level metadatasets. Present to satisfy the shared interface. - split_part: Split to traverse. - - Returns: - The flattened list of traversed leaf dataset references for `split_part`. - """ assert mds_path is None return self._splits[split_part].traverse( self._path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) @@ -297,6 +311,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return self._splits[split_part].get_datasets( @@ -306,5 +321,6 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) diff --git a/src/megatron/energon/metadataset/metadataset_v2.py b/src/megatron/energon/metadataset/metadataset_v2.py index 6f277e95..5f323b1d 100644 --- a/src/megatron/energon/metadataset/metadataset_v2.py +++ b/src/megatron/energon/metadataset/metadataset_v2.py @@ -173,7 +173,7 @@ def merge(self, parent_subset: DatasetSubset | None) -> DatasetSubset: ) -@edataclass +@dataclass(kw_only=True, eq=False) class SubsetRatioMixin: subset: Optional[Subset] = None @@ -191,13 +191,80 @@ def _get_subset(self, parent_subset: Optional[DatasetSubset]) -> Optional[Datase return None +@dataclass(kw_only=True, eq=False) +class GroupMixin: + #: Names a dataset group for train pipelines (blend/shuffle/pack): datasets sharing the + #: same non-None string are blended and shuffled together before packing; packed streams from + #: different groups are then blended. ``None`` means default from outer scopes. + group: Optional[str] = None + + def _merge_group(self, inherited_group: Optional[str]) -> Optional[str]: + if self.group is not None: + if inherited_group is not None: + return f"{inherited_group}+{self.group}" + return self.group + return inherited_group + + +@dataclass(kw_only=True, eq=False) +class ShuffleOverEpochsMultiplierMixin: + shuffle_over_epochs_multiplier: Optional[int] = 1 + + def _merge_shuffle_over_epochs_multiplier( + self, inherited_shuffle_over_epochs_multiplier: Optional[int] + ) -> Optional[int]: + if ( + inherited_shuffle_over_epochs_multiplier is None + or self.shuffle_over_epochs_multiplier is None + ): + # If no shuffling is requested, this has override priority. + return None + elif ( + inherited_shuffle_over_epochs_multiplier == -1 + or self.shuffle_over_epochs_multiplier == -1 + ): + # Next priority is sampling without replacement. + return -1 + else: + # Otherwise, multiply the shuffle over epochs multiplier. + return inherited_shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier + + +@dataclass(kw_only=True, eq=False) +class SubflavorsMixin: + subflavors: Optional[Dict[str, Any]] = None + + def _merge_subflavors(self, inherited_subflavors: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Merge this reference's subflavors with the inherited traversal subflavors. + + The merge order mirrors `get_datasets(...)`: this reference contributes the base mapping, + and inherited outer-hierarchy subflavors override on key conflicts. + + + Args: + inherited_subflavors: Effective subflavors accumulated from outer metadataset + references during traversal. + + Returns: + The effective subflavor mapping for this reference, after applying outer-overrides-inner + merge semantics. + """ + if self.subflavors is not None: + return {**self.subflavors, **(inherited_subflavors or {})} + return dict(inherited_subflavors or {}) + + @edataclass -class DatasetReference(SubsetRatioMixin, DatasetLoaderInterface): +class DatasetReference( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): path: Union[str, EPath] split_part: Optional[str] = None - subflavors: Optional[Dict[str, Any]] = None - shuffle_over_epochs_multiplier: Optional[int] = 1 dataset_config: Optional[str] = None split_config: Optional[str] = None @@ -269,26 +336,6 @@ def _get_traversed_aux_references(self) -> dict[str, EPath]: traversed_aux[key] = value.fs_path return traversed_aux - def _merge_traversed_subflavors( - self, inherited_subflavors: Optional[Dict[str, Any]] - ) -> Dict[str, Any]: - """Merge this reference's subflavors with the inherited traversal subflavors. - - The merge order mirrors `get_datasets(...)`: this reference contributes the base mapping, - and inherited outer-hierarchy subflavors override on key conflicts. - - Args: - inherited_subflavors: Effective subflavors accumulated from outer metadataset - references during traversal. - - Returns: - The effective subflavor mapping for this reference, after applying outer-overrides-inner - merge semantics. - """ - if self.subflavors is not None: - return {**self.subflavors, **(inherited_subflavors or {})} - return dict(inherited_subflavors or {}) - def _load_nested_metadataset(self) -> DatasetLoaderInterface: assert isinstance(self.path, EPath) assert self.aux is None, "Cannot specify auxiliary datasets for crude datasets" @@ -327,33 +374,24 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: - """Traverse this V2 dataset reference into flattened leaf references. - - For direct leaf datasets, traversal resolves the dataset path and any auxiliary references - into plain `EPath` values. For nested metadatasets, traversal recurses immediately into the - referenced split instead of building an intermediate object graph. - Args: - mds_path: Parent metadataset path used internally to resolve relative dataset and - auxiliary paths. Must be set for nested references and inner traversal nodes; - use None only for top-level metadatasets. - split_part: Split inherited from the parent traversal. If this reference defines its - own split override, that split takes precedence for nested traversal and the - returned leaf reference. - - Returns: - A single leaf `TraversedDatasetReference` for direct dataset references, or the - flattened traversal result of the nested metadataset when this reference points to one. - """ self._resolve_path(mds_path) - effective_subflavors = self._merge_traversed_subflavors(_subflavors) + _subflavors = self._merge_subflavors(_subflavors) + _group = self._merge_group(_group) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) ds_type = get_dataset_type(self.path) if ds_type == EnergonDatasetType.METADATASET: return self._load_nested_metadataset().traverse( split_part=self.split_part or split_part, - _subflavors=effective_subflavors, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, + _subflavors=_subflavors, ) self._normalize_aux_references(mds_path, validate=False) return [ @@ -361,7 +399,9 @@ def traverse( path=self.path, split_part=self.split_part or split_part, aux=self._get_traversed_aux_references(), - subflavors=effective_subflavors, + subflavors=_subflavors, + group=_group, + shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, ) ] @@ -378,32 +418,21 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: - if self.subflavors is not None: - subflavors = {**self.subflavors, **(subflavors or {})} assert self._dataset is not None - if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None: - # If no shuffling is requested, this has override priority. - new_shuffle_over_epochs_multiplier = None - elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1: - # Next priority is sampling without replacement. - new_shuffle_over_epochs_multiplier = -1 - else: - # Otherwise, multiply the shuffle over epochs multiplier. - new_shuffle_over_epochs_multiplier = ( - shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier - ) - subset = self._get_subset(subset) - result = self._dataset.get_datasets( training=training, split_part=self.split_part or split_part, worker_config=worker_config, - subflavors=subflavors, - shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier, - subset=subset, + subflavors=self._merge_subflavors(subflavors), + shuffle_over_epochs_multiplier=self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ), + subset=self._get_subset(subset), + group=self._merge_group(group), **kwargs, ) if self.aux is not None: @@ -443,6 +472,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: raise NotImplementedError("traverse_metadataset() does not support joined datasets.") @@ -462,13 +493,17 @@ def get_datasets( @edataclass -class MetadatasetJoin(SubsetRatioMixin, DatasetLoaderInterface): +class MetadatasetJoin( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): join: Union[List[JoinDatasetReference], Dict[str, JoinDatasetReference]] joiner: Union[Type[Sample], Callable[..., Sample]] split_part: Optional[str] = None - subflavors: Optional[Dict[str, Any]] = None - shuffle_over_epochs_multiplier: Optional[int] = 1 dataset_config: Optional[str] = None split_config: Optional[str] = None @@ -514,6 +549,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: raise NotImplementedError("traverse_metadataset() does not support joined datasets.") @@ -531,17 +568,20 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: assert self._dataset is not None, "Missing post_initialize call." - subset = self._get_subset(subset) return self._dataset.get_datasets( training=training, split_part=split_part, worker_config=worker_config, - subflavors=subflavors, - shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, - subset=subset, + subflavors=self._merge_subflavors(subflavors), + shuffle_over_epochs_multiplier=self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ), + subset=self._get_subset(subset), + group=self._merge_group(group), **kwargs, ) @@ -562,10 +602,16 @@ class BlendJoinDatasetReference(BlendWeightMixin, MetadatasetJoin): @edataclass -class MetadatasetBlend(DatasetLoaderInterface, SubsetRatioMixin): +class MetadatasetBlend( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): """Blending of datasets by specifying the sampling weight for the inner datasets.""" - blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference]] + blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference, "MetadatasetBlend"]] def post_initialize(self, mds_path: Optional[EPath] = None): assert mds_path is not None @@ -577,15 +623,24 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: assert mds_path is not None + _group = self._merge_group(_group) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) + _subflavors = self._merge_subflavors(_subflavors) flattened: List[TraversedDatasetReference] = [] for dataset in self.blend: flattened.extend( dataset.traverse( mds_path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) ) @@ -606,9 +661,15 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: subset = self._get_subset(subset) + group = self._merge_group(group) + subflavors = self._merge_subflavors(subflavors) + shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ) sum_weight = sum(dataset.weight for dataset in self.blend) datasets = [] for dataset in self.blend: @@ -619,6 +680,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) if inner_result.blend_mode not in ( @@ -660,13 +722,25 @@ class BlendEpochizedJoinDatasetReference(BlendRepetitionsMixin, MetadatasetJoin) @edataclass -class MetadatasetBlendEpochized(SubsetRatioMixin, DatasetLoaderInterface): +class MetadatasetBlendEpochized( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): """Blending of datasets, by specifying the number of repetitions for samples from the inner datasets. Ensures that the constraint, that samples are seen exactly this many times before repeating the "epoch" (i.e. one epoch contains the total number of repetitions for each inner dataset).""" - blend_epochized: List[Union[BlendEpochizedDatasetReference, BlendEpochizedJoinDatasetReference]] + blend_epochized: List[ + Union[ + BlendEpochizedDatasetReference, + BlendEpochizedJoinDatasetReference, + "MetadatasetBlendEpochized", + ] + ] def post_initialize(self, mds_path: Optional[EPath] = None): assert mds_path is not None @@ -678,15 +752,24 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: assert mds_path is not None flattened: List[TraversedDatasetReference] = [] + _group = self._merge_group(_group) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) + _subflavors = self._merge_subflavors(_subflavors) for dataset in self.blend_epochized: flattened.extend( dataset.traverse( mds_path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) ) @@ -707,9 +790,15 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: subset = self._get_subset(subset) + shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ) + group = self._merge_group(group) + subflavors = self._merge_subflavors(subflavors) datasets = [] for dataset in self.blend_epochized: inner_result = dataset.get_datasets( @@ -719,6 +808,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) if inner_result.blend_mode not in ( @@ -760,6 +850,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: """Traverse the selected V2 split and flatten all reachable leaf references. @@ -775,6 +867,8 @@ def traverse( return self.splits[split_part].traverse( self.path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) @@ -808,6 +902,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return self.splits[split_part].get_datasets( @@ -817,5 +912,6 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 5335d018..f083c8c3 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -56,6 +56,7 @@ PackingDataset, ShuffleBufferDataset, ) +from megatron.energon.wrappers.packing_dataset import PackedSamplesOutput from megatron.energon.wrappers.repeat_dataset import RepeatDataset T = TypeVar("T") @@ -536,16 +537,18 @@ def _batch( def select_samples_to_pack( self, samples: List[T_encoded_sample] - ) -> List[List[T_encoded_sample]]: + ) -> list[list[T_encoded_sample]] | PackedSamplesOutput[T_encoded_sample]: """ For packing, selects the samples to be packed together. Packing is only active when packing_buffer_size is set. Internally this stage is called "pre_packing". Args: - samples: The samples to pre-pack. A full buffer will be passed into the function. + samples: The samples to pre-pack (a full reading buffer per call when ``packing_buffer_size`` is set). - Returns: The pre-packed samples as a list of lists of samples. + Returns: + Either a ``list[list[T]]`` of packs, or :class:`PackedSamplesOutput` + to attach a ``pushback`` sequence reapplied to the reading buffer before the next fill. """ raise NotImplementedError("Packing only effective when overridden.") @@ -562,63 +565,88 @@ def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_encoded_sa """ raise NotImplementedError("Packing only effective when overridden.") - def build_batch( + def _build_packing_postencode( self, dataset: SavableDataset[T_encoded_sample], *, - batch_size: Optional[int], - batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, + group: Optional[str], + packing_buffer_size: int | dict[str | None, int | None] | None, worker_config: WorkerConfig, - ) -> SavableDataset[T_raw_batch]: - """Applies the batcher to the dataset.""" - - dataset: SavableDataset[Any] + ) -> SavableDataset[T_encoded_sample]: + """Builds the (packing +) post-encode stage after encoding. - if packing_buffer_size is not None: - select_samples_to_pack_provided = self._is_overridden(self.select_samples_to_pack) - pack_selected_samples_provided = self._is_overridden(self.pack_selected_samples) + When ``packing_buffer_size`` is a dict, selects the buffer size (or ``None`` to disable + packing) for this leaf's dataset group via ``group`` (must match + :attr:`~megatron.energon.metadataset.loader_interface.LoadedDataset.group`). - assert select_samples_to_pack_provided and pack_selected_samples_provided, ( - "Both select_samples_to_pack and pack_selected_samples methods must be provided in the TaskEncoder when using packing_buffer_size" - ) + Args: + dataset: Encoded sample stream for one group. + group: Key into ``packing_buffer_size`` when it is a dict; ``None`` is the default group. + packing_buffer_size: Global buffer size, per-group mapping, or ``None`` to disable packing. + worker_config: Worker configuration for wrapped datasets. + """ + if isinstance(packing_buffer_size, dict): + packing_buffer_size = packing_buffer_size[group] + if packing_buffer_size is None: if self._is_overridden(self.postencode_sample): - post_encode_fn = self.postencode_sample - else: - post_encode_fn = None + dataset = MapDataset( + dataset, + self.postencode_sample, + worker_config=worker_config, + stateless_map_fn=get_stateless(self.postencode_sample), + failure_tolerance=get_failure_tolerance( + self.postencode_sample, self.__default_failure_tolerance__ + ), + ) + return dataset - dataset = PackingDataset( - dataset, - buffer_size=packing_buffer_size, - pre_packer=self.select_samples_to_pack, - final_packer=self.pack_selected_samples, - final_packer_stateless=get_stateless(self.pack_selected_samples), - sample_encoder=post_encode_fn, - sample_encoder_stateless=True - if post_encode_fn is None - else get_stateless(post_encode_fn), - worker_config=worker_config, - pre_packer_failure_tolerance=get_failure_tolerance( - self.select_samples_to_pack, self.__default_failure_tolerance__ - ), - final_packer_failure_tolerance=get_failure_tolerance( - self.pack_selected_samples, self.__default_failure_tolerance__ - ), - sample_encoder_failure_tolerance=0 - if post_encode_fn is None - else get_failure_tolerance(post_encode_fn, self.__default_failure_tolerance__), - ) - elif self._is_overridden(self.postencode_sample): - dataset = MapDataset( - dataset, - self.postencode_sample, - worker_config=worker_config, - stateless_map_fn=get_stateless(self.postencode_sample), - failure_tolerance=get_failure_tolerance( - self.postencode_sample, self.__default_failure_tolerance__ - ), + select_samples_to_pack_provided = self._is_overridden(self.select_samples_to_pack) + pack_selected_samples_provided = self._is_overridden(self.pack_selected_samples) + + assert select_samples_to_pack_provided and pack_selected_samples_provided, ( + "Both select_samples_to_pack and pack_selected_samples methods must be provided in the TaskEncoder when using packing_buffer_size" + ) + + if self._is_overridden(self.postencode_sample): + post_encode_fn = self.postencode_sample + post_encode_stateless = get_stateless(self.postencode_sample) + post_encode_failure_tolerance = get_failure_tolerance( + self.postencode_sample, self.__default_failure_tolerance__ ) + else: + post_encode_fn = None + post_encode_stateless = True + post_encode_failure_tolerance = 0 + + return PackingDataset( + dataset, + buffer_size=packing_buffer_size, + pre_packer=self.select_samples_to_pack, + final_packer=self.pack_selected_samples, + final_packer_stateless=get_stateless(self.pack_selected_samples), + sample_encoder=post_encode_fn, + sample_encoder_stateless=post_encode_stateless, + worker_config=worker_config, + pre_packer_failure_tolerance=get_failure_tolerance( + self.select_samples_to_pack, self.__default_failure_tolerance__ + ), + final_packer_failure_tolerance=get_failure_tolerance( + self.pack_selected_samples, self.__default_failure_tolerance__ + ), + sample_encoder_failure_tolerance=post_encode_failure_tolerance, + ) + + def build_batch( + self, + dataset: SavableDataset[T_encoded_sample], + *, + batch_size: Optional[int], + batch_drop_last: bool = False, + worker_config: WorkerConfig, + ) -> SavableDataset[T_raw_batch]: + """Applies the batcher to the dataset.""" + dataset: SavableDataset[Any] if self._is_overridden(self.batch_group_criterion): dataset = GroupBatchDataset( @@ -769,33 +797,36 @@ def build_encode_sample( ) return dataset - def build_train_datasets( + def _group_weight( + self, + group_ds: List[LoadedDataset], + blend_mode: DatasetBlendMode, + *, + repeat: bool, + ) -> float: + """Blend weight for one dataset group when merging packed streams after grouping.""" + if blend_mode == DatasetBlendMode.DATASET_WEIGHT: + return sum(float(d.weight) for d in group_ds) + if blend_mode == DatasetBlendMode.SAMPLE_REPETITIONS or ( + not repeat and blend_mode == DatasetBlendMode.NONE + ): + return sum( + len(d.dataset) * (1 if d.repetitions is None else float(d.repetitions)) + for d in group_ds + ) + return float(len(group_ds)) + + def _build_train_blend_shuffle_encode_branch( self, *, datasets: List[LoadedDataset], + worker_rotation_offsets: List[int], + blend_mode: DatasetBlendMode, + repeat: bool, + shuffle_buffer_size: Optional[int], worker_config: WorkerConfig, - batch_size: Optional[int], - batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, - virtual_epoch_length: int = 0, - shuffle_buffer_size: Optional[int] = None, - blend_mode: DatasetBlendMode = DatasetBlendMode.NONE, - repeat: bool = True, - ) -> SavableDataset[T_batch]: - """Combines train datasets to a single dataset.""" - - # Check if there's a CrudeWebdataset but no cookers - for dataset in datasets: - if isinstance(dataset.dataset, CrudeWebdataset): - assert self.cookers, "CrudeWebdataset found, but no cookers registered." - - global_workers = max(1, worker_config.num_workers) * worker_config.world_size - rotation_lengths = [len(dataset.dataset) for dataset in datasets] - for i in range(1, len(rotation_lengths)): - rotation_lengths[i] += rotation_lengths[i - 1] - worker_rotation_offsets = [ - rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1] - ] + ) -> SavableDataset[T_encoded_sample]: + """Builds the (blend) → (repeat/shuffle) → (preencode/encode) pipeline.""" if blend_mode == DatasetBlendMode.DATASET_WEIGHT: assert repeat, ( @@ -874,14 +905,134 @@ def build_train_datasets( size=shuffle_buffer_size, worker_config=worker_config, ) - dataset = self.build_encode_sample(dataset, worker_config=worker_config) + return self.build_encode_sample(dataset, worker_config=worker_config) + + def _compute_rotation_offsets( + self, datasets: List[LoadedDataset], worker_config: WorkerConfig + ) -> List[int]: + global_workers = max(1, worker_config.num_workers) * worker_config.world_size + rotation_lengths = [len(d.dataset) for d in datasets] + for i in range(1, len(rotation_lengths)): + rotation_lengths[i] += rotation_lengths[i - 1] + return [rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1]] + + def _build_train_blend_shuffle_encode_groups( + self, + datasets: List[LoadedDataset], + packing_buffer_size: Optional[int | dict[str | None, int | None]], + blend_mode: DatasetBlendMode, + repeat: bool, + shuffle_buffer_size: Optional[int | dict[str | None, int | None]], + worker_config: WorkerConfig, + ) -> SavableDataset[T_encoded_sample]: + """Builds the train pipeline with optional per-dataset-group isolation. + + Splits ``datasets`` by :attr:`~megatron.energon.metadataset.loader_interface.LoadedDataset.group`. + For each group, runs blend → (optional shuffle) → encode, then applies packing/postencode for + that group's ``packing_buffer_size`` / ``shuffle_buffer_size`` entries. When multiple groups + exist, blends the resulting streams with weights from :meth:`_group_weight`. + + Pipeline per group: + ``blend → shuffle → encode → select_samples_to_pack → postencode → pack_selected_samples``. + Multiple groups: ``(... per group ...) → blend``. + """ + rotation_offsets = self._compute_rotation_offsets(datasets, worker_config) + + dataset_groups: dict[Optional[str], tuple[list[LoadedDataset], list[int]]] = {} + for ld, ro in zip(datasets, rotation_offsets): + if ld.group in dataset_groups: + dataset_groups[ld.group][0].append(ld) + dataset_groups[ld.group][1].append(ro) + else: + dataset_groups[ld.group] = ([ld], [ro]) + + streams: List[tuple[SavableDataset[Any], float]] = [] + for group_key, (group_ds, rotation_offsets) in dataset_groups.items(): + if isinstance(shuffle_buffer_size, dict): + group_shuffle_buffer_size = shuffle_buffer_size[group_key] + else: + group_shuffle_buffer_size = shuffle_buffer_size + + dataset = self._build_train_blend_shuffle_encode_branch( + datasets=group_ds, + worker_rotation_offsets=rotation_offsets, + blend_mode=blend_mode, + repeat=repeat, + shuffle_buffer_size=group_shuffle_buffer_size, + worker_config=worker_config, + ) + # Post-encode is included + dataset = self._build_packing_postencode( + dataset, + group=group_key, + packing_buffer_size=packing_buffer_size, + worker_config=worker_config, + ) + streams.append( + ( + dataset, + self._group_weight(group_ds, blend_mode, repeat=repeat), + ) + ) + + if len(streams) > 1: + return BlendDataset(*streams, worker_config=worker_config) + else: + return streams[0][0] + + def build_train_datasets( + self, + *, + datasets: List[LoadedDataset], + worker_config: WorkerConfig, + batch_size: Optional[int], + batch_drop_last: bool = False, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, + virtual_epoch_length: int = 0, + shuffle_buffer_size: Optional[int | dict[str | None, int | None]] = None, + blend_mode: DatasetBlendMode = DatasetBlendMode.NONE, + repeat: bool = True, + ) -> SavableDataset[T_batch]: + """Combines train datasets into one batched dataset pipeline. + + Args: + datasets: Loaded leaf datasets (each carries a ``group`` key when using Metadataset V2). + worker_config: Worker configuration for wrapped datasets. + batch_size: Batch dimension; ``None`` skips batching. + batch_drop_last: If true, drop the last batch when smaller than ``batch_size``. + packing_buffer_size: Packing buffer size, or a dict mapping dataset group keys + (including ``None`` for the default group) to sizes or ``None`` to disable packing + per group. + virtual_epoch_length: If positive, wraps with epochization at this length. + shuffle_buffer_size: Shuffle buffer before encoding, or per-group dict like + ``packing_buffer_size``. + blend_mode: How leaf weights map to the inner :class:`~megatron.energon.BlendDataset`. + repeat: Whether inner datasets loop indefinitely. + + Returns: + The full train :class:`~megatron.energon.flavors.SavableDataset` pipeline. + """ + + # Check if there's a CrudeWebdataset but no cookers + for dataset in datasets: + if isinstance(dataset.dataset, CrudeWebdataset): + assert self.cookers, "CrudeWebdataset found, but no cookers registered." + + dataset = self._build_train_blend_shuffle_encode_groups( + datasets=datasets, + packing_buffer_size=packing_buffer_size, + blend_mode=blend_mode, + repeat=repeat, + shuffle_buffer_size=shuffle_buffer_size, + worker_config=worker_config, + ) dataset = self.build_batch( dataset, batch_size=batch_size, batch_drop_last=batch_drop_last, - packing_buffer_size=packing_buffer_size, worker_config=worker_config, ) + if virtual_epoch_length > 0: dataset = EpochizeDataset( dataset, @@ -893,31 +1044,13 @@ def build_train_datasets( return dataset - def build_val_datasets( + def _build_val_concat_encode_branch( self, - *, datasets: List[LoadedDataset], + worker_rotation_offsets: List[int], worker_config: WorkerConfig, - batch_size: int, - batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, - limit: Optional[int] = None, - ) -> SavableDataset[T_batch]: - """Combines val datasets to a single dataset.""" - - # Check if there's a CrudeWebdataset but no cookers - for dataset in datasets: - if isinstance(dataset, CrudeWebdataset): - assert self.cookers, "CrudeWebdataset found, but no cookers registered." - - global_workers = max(1, worker_config.num_workers) * worker_config.world_size - rotation_lengths = [len(dataset.dataset) for dataset in datasets] - for i in range(1, len(rotation_lengths)): - rotation_lengths[i] += rotation_lengths[i - 1] - worker_rotation_offsets = [ - rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1] - ] - + ) -> SavableDataset[T_encoded_sample]: + """Builds the (concat) → (preencode/encode) pipeline.""" if len(datasets) > 1: dataset = ConcatDataset( *[ @@ -930,12 +1063,93 @@ def build_val_datasets( dataset = self._load_dataset(datasets[0], worker_rotation_offsets[0], worker_config) else: raise ValueError("No datasets given.") - dataset = self.build_encode_sample(dataset, worker_config=worker_config) + return self.build_encode_sample(dataset, worker_config=worker_config) + + def _build_val_concat_encode_groups( + self, + datasets: List[LoadedDataset], + packing_buffer_size: Optional[int | dict[str | None, int | None]], + worker_config: WorkerConfig, + ) -> SavableDataset[T_encoded_sample]: + """Builds the validation pipeline with optional per-dataset-group isolation. + + Like :meth:`_build_train_blend_shuffle_encode_groups`, but concatenates leaves instead of + blending, and omits shuffle/repeat. Splits ``datasets`` by ``LoadedDataset.group``, applies + packing per group's ``packing_buffer_size`` entry, then concatenates group streams when needed. + + Pipeline per group: + ``concat loaded leaves → encode → select_samples_to_pack → postencode → pack_selected_samples``. + Multiple groups: ``(... per group ...) → concat``. + """ + rotation_offsets = self._compute_rotation_offsets(datasets, worker_config) + + dataset_groups: dict[Optional[str], tuple[list[LoadedDataset], list[int]]] = {} + for ld, ro in zip(datasets, rotation_offsets): + if ld.group in dataset_groups: + dataset_groups[ld.group][0].append(ld) + dataset_groups[ld.group][1].append(ro) + else: + dataset_groups[ld.group] = ([ld], [ro]) + + streams: List[SavableDataset[Any]] = [] + for group_key, (group_ds, rotation_offsets) in dataset_groups.items(): + branch = self._build_val_concat_encode_branch( + datasets=group_ds, + worker_rotation_offsets=rotation_offsets, + worker_config=worker_config, + ) + # Post-encode is included + branch = self._build_packing_postencode( + branch, + group=group_key, + packing_buffer_size=packing_buffer_size, + worker_config=worker_config, + ) + streams.append(branch) + if len(streams) > 1: + return ConcatDataset(*streams, worker_config=worker_config) + else: + return streams[0] + + def build_val_datasets( + self, + *, + datasets: List[LoadedDataset], + worker_config: WorkerConfig, + batch_size: int, + batch_drop_last: bool = False, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, + limit: Optional[int] = None, + ) -> SavableDataset[T_batch]: + """Combines validation datasets into one batched dataset pipeline. + + Args: + datasets: Loaded leaf datasets (each may carry a ``group`` key when using Metadataset V2). + worker_config: Worker configuration for wrapped datasets. + batch_size: Batch dimension. + batch_drop_last: If true, drop the last batch when smaller than ``batch_size``. + packing_buffer_size: Packing buffer size, or per-group dict (keys include ``None`` for the + default group). + limit: If set and positive, caps the number of batches via :class:`~megatron.energon.LimitDataset`. + + Returns: + The full validation :class:`~megatron.energon.flavors.SavableDataset` pipeline. + """ + + # Check if there's a CrudeWebdataset but no cookers + for dataset in datasets: + if isinstance(dataset, CrudeWebdataset): + assert self.cookers, "CrudeWebdataset found, but no cookers registered." + + dataset = self._build_val_concat_encode_groups( + datasets=datasets, + packing_buffer_size=packing_buffer_size, + worker_config=worker_config, + ) dataset = self.build_batch( dataset, batch_size=batch_size, batch_drop_last=batch_drop_last, - packing_buffer_size=packing_buffer_size, worker_config=worker_config, ) if limit is not None and limit > 0: diff --git a/src/megatron/energon/task_encoder/loader.py b/src/megatron/energon/task_encoder/loader.py index 5c680bb2..4e3d0bc7 100644 --- a/src/megatron/energon/task_encoder/loader.py +++ b/src/megatron/energon/task_encoder/loader.py @@ -111,8 +111,8 @@ def get_train_dataset( worker_config: WorkerConfig, batch_size: Optional[int], batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, - shuffle_buffer_size: Optional[int], + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, + shuffle_buffer_size: Optional[int | dict[str | None, int | None]], max_samples_per_sequence: Optional[int], virtual_epoch_length: int = 0, shuffle_over_epochs_multiplier: Optional[int] = 1, @@ -137,7 +137,12 @@ def get_train_dataset( worker_config: Worker configuration to use. batch_size: Size of a batch. If None, do not batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. - shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding). + packing_buffer_size: Size of the packing buffer. If a dict, keys are dataset group names from + Metadataset V2 (YAML ``group``, merged along the path; ``None`` is the default group). + Values are buffer sizes, or ``None`` to disable packing for that group. + shuffle_buffer_size: Sample shuffle buffer size before task encoding. If a dict, keys are the + same dataset group names as for ``packing_buffer_size``; values are buffer sizes or ``None`` + to disable shuffling for that group. max_samples_per_sequence: If set, limit the number of samples per sample-sequence to this. virtual_epoch_length: If set, the dataset will be epochized to this length (=iterating will be suspended and the for-loop returns, next for-loop continues iterating). @@ -186,7 +191,7 @@ def get_val_dataset( worker_config: WorkerConfig, batch_size: int, batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, limit: Optional[int] = None, task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(), **kwargs, @@ -208,6 +213,9 @@ def get_val_dataset( worker_config: Worker configuration to use. batch_size: Size of a batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. + packing_buffer_size: Size of the packing buffer. If a dict, keys are dataset group names from + Metadataset V2 (same as ``get_train_dataset``); values are buffer sizes or ``None`` to + disable packing per group. limit: If set, limit the number of batches loaded from the dataset to this. task_encoder: Task encoder to use. **kwargs: Additional arguments to the dataset constructor. @@ -241,7 +249,7 @@ def get_val_datasets( worker_config: WorkerConfig, batch_size: int, batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, limit: Optional[int] = None, task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(), **kwargs, @@ -263,6 +271,9 @@ def get_val_datasets( worker_config: Worker configuration to use. batch_size: Size of a batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. + packing_buffer_size: Size of the packing buffer. If a dict, keys are dataset group names from + Metadataset V2 (same as ``get_train_dataset``); values are buffer sizes or ``None`` to + disable packing per group. limit: If set, limit the number of batches loaded from the dataset to this. task_encoder: Task encoder to use. **kwargs: Additional arguments to the dataset constructor. diff --git a/src/megatron/energon/wrappers/__init__.py b/src/megatron/energon/wrappers/__init__.py index 808ca50e..955e1b1d 100644 --- a/src/megatron/energon/wrappers/__init__.py +++ b/src/megatron/energon/wrappers/__init__.py @@ -18,7 +18,7 @@ generic_concat, homogeneous_concat_mix, ) -from megatron.energon.wrappers.packing_dataset import PackingDataset +from megatron.energon.wrappers.packing_dataset import PackedSamplesOutput, PackingDataset from megatron.energon.wrappers.repeat_dataset import RepeatDataset from megatron.energon.wrappers.shuffle_buffer_dataset import ShuffleBufferDataset from megatron.energon.wrappers.skip import SkipSample @@ -40,6 +40,7 @@ "ShuffleBufferDataset", "SkipSample", "PackingDataset", + "PackedSamplesOutput", "concat_pad", "generic_concat", "homogeneous_concat_mix", diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index ee10e6c2..a81ad4e4 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -3,6 +3,7 @@ import contextlib import inspect +from dataclasses import dataclass, field from typing import ( Any, Callable, @@ -12,6 +13,7 @@ Iterator, List, Optional, + Sequence, TypeVar, Union, ) @@ -31,6 +33,19 @@ T_batch_sample = TypeVar("T_batch_sample") +@dataclass(slots=True, frozen=True) +class PackedSamplesOutput(Generic[T_sample]): + """Return type for :meth:`~megatron.energon.TaskEncoder.select_samples_to_pack` when including + samples to push back onto the packing reading buffer.""" + + #: The packs of samples to be packed together. + # One entry per pack, each containing the samples to be packed together. + packs: list[list[T_sample]] + + #: The samples to push back onto the packing reading buffer for the next fill. + pushback: Sequence[T_sample] = field(default_factory=tuple) + + class PackingDataset( BaseWrapperDataset[T_sample, T_batch_sample], Generic[T_sample, T_encoded_sample, T_batch_sample], @@ -39,7 +54,7 @@ class PackingDataset( then combined into a batch.""" buffer_size: int - pre_packer: Callable[[List[T_sample]], List[List[T_sample]]] + pre_packer: Callable[[list[T_sample]], list[list[T_sample]] | PackedSamplesOutput[T_sample]] sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]] sample_encoder_stateless: bool final_packer: Callable[[List[T_encoded_sample]], T_batch_sample] @@ -86,7 +101,9 @@ def __init__( self, dataset: SavableDataset[T_sample], buffer_size: int, - pre_packer: Callable[[List[T_sample]], List[List[T_sample]]], + pre_packer: Callable[ + [list[T_sample]], list[list[T_sample]] | PackedSamplesOutput[T_sample] + ], final_packer: Callable[[List[T_encoded_sample]], T_batch_sample], *, final_packer_stateless: bool = False, @@ -108,7 +125,8 @@ def __init__( dataset: The input dataset to wrap buffer_size: The desired size of the input buffer for pre packing. Last buffer of a dataset may be smaller. pre_packer: Function which selects samples from the buffer to be packed together. - May raise :exc:`megatron.energon.SkipSample` to skip a buffer. + May return :class:`PackedSamplesOutput` to include a ``pushback`` sequence + reapplied to the reading buffer. May raise :exc:`megatron.energon.SkipSample` to skip a buffer. final_packer: Function which combines the selected samples into a single sample. final_packer_stateless: If True, the final_packer is stateless, thus samples can be stored/restored. @@ -245,26 +263,35 @@ def next_pre_pack(): together.""" assert self._pre_packing_buffer.len_worker() == 0 - if self._reading_buffer.len_worker() > 0: - # Take all samples from the reading buffer and pre_pack them - samples = self._reading_buffer.buffer.copy() - # Clear buffer and pre_packing_lengths - self._reading_buffer.clear() - pre_packing_lengths.clear() - # Now pre pack the samples - pre_packs = [] - with self._pre_pack_failure_handler.handle_errors(samples): - with self._pre_packing_sample_index.ctx(): - pre_packs = self.pre_packer(samples) - - # Put the pre-packed samples into the pre_packing_buffer - # They will be flattened here to avoid nested buffers - # But the lengths of the groups are stored in pre_packing_lengths - # so that the groups can be separated later - for pre_pack in pre_packs: - if len(pre_pack) > 0: - self._pre_packing_buffer.extend(pre_pack) - pre_packing_lengths.append(len(pre_pack)) + if self._reading_buffer.len_worker() == 0: + return + + # Take all samples from the reading buffer and pre_pack them + samples = self._reading_buffer.buffer.copy() + # Clear buffer and pre_packing_lengths + self._reading_buffer.clear() + pre_packing_lengths.clear() + # Now pre pack the samples + pre_packs_result: list[list[T_sample]] | PackedSamplesOutput[T_sample] | None = None + with self._pre_pack_failure_handler.handle_errors(samples): + with self._pre_packing_sample_index.ctx(): + pre_packs_result = self.pre_packer(samples) + if pre_packs_result is None: + # The error handler may actually catch the error + return + # Put the pre-packed samples into the pre_packing_buffer + # They will be flattened here to avoid nested buffers + # But the lengths of the groups are stored in pre_packing_lengths + # so that the groups can be separated later + if isinstance(pre_packs_result, PackedSamplesOutput): + pre_packs = pre_packs_result.packs + self._reading_buffer.extend(list(pre_packs_result.pushback)) + else: + pre_packs = pre_packs_result + for pre_pack in pre_packs: + if len(pre_pack) > 0: + self._pre_packing_buffer.extend(pre_pack) + pre_packing_lengths.append(len(pre_pack)) def next_final_pack() -> Generator[T_batch_sample, None, None]: """Yield the next packs from the buffer. The final packer is called on the fly.""" @@ -367,19 +394,15 @@ def restore_sample(self, restore_key: Any) -> T_sample: # We need to store multiple indices to restore a batch. self.assert_can_restore() if inspect.isgeneratorfunction(self.final_packer): - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ else: - id, pack_idx, *pack_restore_keys = restore_key id, pack_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ pack = [] for inner_idx in pack_restore_keys: if self.sample_encoder is not None: - id, sample_idx, *inner_idx = inner_idx - assert id == type(self).__name__ id, sample_idx, *inner_idx = inner_idx assert id == type(self).__name__ assert isinstance(sample_idx, int) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 01aa74ee..e795a39b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -34,6 +34,7 @@ DefaultTaskEncoder, MapDataset, MixBatchDataset, + PackedSamplesOutput, Sample, SavableDataLoader, TaskEncoder, @@ -1561,6 +1562,96 @@ def pack_selected_samples( assert restored_sample_1.__key__ == samples[1].__key__ assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ + def test_packing_pushback(self): + """Pushback must prepend deferred samples ahead of new iterator pulls. + + With ``buffer_size`` 5, the first full buffer follows the dataset iterator order (not + necessarily sorted ``__key__``). We pack only the first sample and push the rest back; the + next pre-pack must see the four deferred samples first, then one newly drawn sample — i.e. + the second buffer's keys must start with the first buffer's keys at indices ``[1:5]``. + """ + torch.manual_seed(42) + + buf = 5 + + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + self.buffer_keys_per_call: list[list[str]] = [] + self.first_fill_keys: list[str] | None = None + + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), + ) + + def select_samples_to_pack( + self, samples: list[EncodedCaptioningSample] + ) -> PackedSamplesOutput[EncodedCaptioningSample]: + keys = [s.__key__ for s in samples] + self.buffer_keys_per_call.append(keys.copy()) + call = len(self.buffer_keys_per_call) + if call == 1: + assert len(samples) == buf + self.first_fill_keys = keys.copy() + return PackedSamplesOutput( + packs=[samples[:1]], + pushback=tuple(samples[1:]), + ) + if call == 2: + assert len(samples) == buf + assert self.first_fill_keys is not None + assert keys[: buf - 1] == self.first_fill_keys[1:buf], ( + "Pushback did not appear before new samples: " + f"second_fill[:4]={keys[: buf - 1]!r} " + f"expected first_fill[1:5]={self.first_fill_keys[1:buf]!r}" + ) + return PackedSamplesOutput( + packs=[[s] for s in samples], + pushback=(), + ) + return PackedSamplesOutput( + packs=[[s] for s in samples], + pushback=(), + ) + + @stateless + def pack_selected_samples( + self, samples: list[EncodedCaptioningSample] + ) -> EncodedCaptioningSample: + return EncodedCaptioningSample( + __key__=",".join(sample.__key__ for sample in samples), + __restore_key__=(), + image=torch.stack([sample.image for sample in samples]), + caption=torch.cat([sample.caption for sample in samples]), + ) + + task_encoder = TestTaskEncoder() + loader = get_loader( + get_train_dataset( + self.dataset_path, + batch_size=2, + packing_buffer_size=buf, + worker_config=no_worker_config, + virtual_epoch_length=4, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=task_encoder, + ) + ) + + out = list(loader) + assert len(out) == 4 + assert task_encoder.first_fill_keys is not None + assert ( + task_encoder.buffer_keys_per_call[1][: buf - 1] == task_encoder.first_fill_keys[1:buf] + ) + restored = loader.restore_sample(out[0].__restore_key__) + assert restored.__key__ == out[0].__key__ + def test_group_batch(self): class GroupingTaskEncoder( TaskEncoder[CaptioningSample, CaptioningSample, CaptioningSample, CaptioningSample] diff --git a/tests/test_metadataset_v2.py b/tests/test_metadataset_v2.py index e33992f7..c37ff654 100644 --- a/tests/test_metadataset_v2.py +++ b/tests/test_metadataset_v2.py @@ -32,7 +32,7 @@ from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.metadataset.loader import prepare_metadataset, traverse_metadataset from megatron.energon.metadataset.loader_interface import DatasetBlendMode -from megatron.energon.task_encoder.base import DefaultTaskEncoder +from megatron.energon.task_encoder.base import DefaultTaskEncoder, stateless from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset from tests.epath_s3_emulator import setup_s3_emulator @@ -257,6 +257,86 @@ def test_metadataset(self): assert len(Counter(train_order1)) == 110 assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + def test_group(self): + """Group-specific shuffle + packing keeps returned samples source-homogeneous.""" + mds_path = self.dataset_path / "group_blend.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " path: ds1", + " group: alpha", + " subflavors:", + " packing_source: ds1", + " - weight: 1", + " path: ds2", + " group: beta", + " subflavors:", + " packing_source: ds2", + ] + ) + ) + + leaves = traverse_metadataset(mds_path, split_part="train") + assert len(leaves) == 2 + assert {ref.group for ref in leaves} == {"alpha", "beta"} + + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0, seed_offset=0) + loaded = load_dataset(mds_path).get_datasets( + training=True, + split_part="train", + worker_config=worker_config, + ) + assert loaded.blend_mode == DatasetBlendMode.DATASET_WEIGHT + assert {d.group for d in loaded.datasets} == {"alpha", "beta"} + + class GroupIsolationEncoder(DefaultTaskEncoder): + """Each returned packed sample must come from exactly one packing source.""" + + @stateless + def encode_sample(self, sample: TextSample) -> TextSample: + return sample + + def select_samples_to_pack(self, samples: list[TextSample]) -> list[list[TextSample]]: + return [samples] + + @stateless + def pack_selected_samples(self, samples: list[TextSample]) -> TextSample: + packing_sources = { + sample.__subflavors__.get("packing_source") + for sample in samples + if sample.__subflavors__ is not None + } + assert len(packing_sources) == 1, ( + "Mixed sources in one packed sample: " + f"{packing_sources} for keys {[sample.__key__ for sample in samples]}" + ) + return TextSample.derive_from( + samples[0], + __key__=",".join(s.__key__ for s in samples), + __restore_key__=(), + text=f"{next(iter(packing_sources))}:" + " | ".join(s.text for s in samples), + ) + + torch.manual_seed(42) + packed_ds = get_train_dataset( + mds_path, + worker_config=worker_config, + batch_size=2, + packing_buffer_size=8, + shuffle_buffer_size={"alpha": 8, "beta": None}, + max_samples_per_sequence=None, + task_encoder=GroupIsolationEncoder(), + virtual_epoch_length=10, + ) + list(get_loader(packed_ds)) + def test_nested_metadataset(self): torch.manual_seed(42) worker_config = WorkerConfig( @@ -375,9 +455,15 @@ def test_traverse_metadataset_preserves_missing_v2_leaf_and_aux(self): "splits:", " train:", " path: missing_ds", + " subflavors:", + " source: missing_leaf_metadataset_v2.yaml", + " number: 42", + " mds: nested_val", " aux:", " labels: missing_aux", " media: filesystem://media", + " group: abc", + " shuffle_over_epochs_multiplier: 2", ] ), encoding="utf-8", @@ -392,6 +478,13 @@ def test_traverse_metadataset_preserves_missing_v2_leaf_and_aux(self): "labels": EPath(self.dataset_path / "missing_aux"), "media": EPath(self.dataset_path / "media"), } + assert refs[0].subflavors == { + "source": "missing_leaf_metadataset_v2.yaml", + "number": 42, + "mds": "nested_val", + } + assert refs[0].group == "abc" + assert refs[0].shuffle_over_epochs_multiplier == 2 def test_joined_metadataset(self): torch.manual_seed(42)