2 changes: 2 additions & 0 deletions docs/source/advanced/epochized_blending.md
@@ -6,6 +6,8 @@ SPDX-License-Identifier: BSD-3-Clause -->
As an alternative to blending with a weight for each dataset, epochized blending makes the blend accurate and
lets iteration over the dataset follow epochs (i.e. iteration is interrupted after each epoch).

If you instead want **sampling proportions that change over training** (rather than an epoch-defined repetition scheme), see the scheduled blend weights in [](../basic/metadataset).

Here is an example `metadataset.yaml` config file that changes to epochized blending:

```yaml
42 changes: 42 additions & 0 deletions docs/source/basic/metadataset.md
@@ -35,6 +35,48 @@ splits:
In the above example, we create a blend of three datasets. Out of the yielded training samples, 62.5% ({math}`=\frac{5}{8}`) will come from `./coco`, 25% from `./coyo` and 12.5% from `./other`.
Note that the relative paths in the metadataset are relative to the location of the metadataset file. Absolute paths are allowed but won't work for object storage.
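
The percentages above follow from dividing each weight by the sum of all weights. As a minimal sketch, assuming the three entries carry the weights 5, 2 and 1 (any weights in the same ratio give the same result; `blend_proportions` is a hypothetical helper, not part of energon):

```python
def blend_proportions(weights: dict[str, float]) -> dict[str, float]:
    """Normalize raw blend weights into sampling proportions."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("Sum of blend weights must be > 0")
    return {path: weight / total for path, weight in weights.items()}


# Assumed weights in the ratio 5:2:1, matching the 5/8 share stated above.
proportions = blend_proportions({"./coco": 5, "./coyo": 2, "./other": 1})
print(proportions)  # {'./coco': 0.625, './coyo': 0.25, './other': 0.125}
```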

## Scheduled weights

Instead of a constant number for `weight:`, you can specify a schedule that changes the sampling proportions over training as a function of the **batch index**.

Example (step schedule):

```yaml
splits:
  train:
    blend:
      - weight: 1
        path: ./ds1
      - weight:
          step:
            0: 100
            1500: 10
            3000: 0
        path: ./ds2
```
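
To see what this schedule means for the blend, the sketch below computes the share of each dataset at the three knot points. It assumes that scheduled weights are normalized against the other weights at each batch index, in the same way as constant weights; `step_weight` and `shares` are hypothetical helpers, not energon functions.

```python
def step_weight(points: dict[int, float], batch_idx: int) -> float:
    """Value of the last knot whose key is <= batch_idx (hypothetical helper)."""
    return points[max(key for key in points if key <= batch_idx)]


ds2_schedule = {0: 100, 1500: 10, 3000: 0}


def shares(batch_idx: int) -> tuple[float, float]:
    """Sampling shares of (ds1, ds2), assuming per-batch-index normalization."""
    w1 = 1.0  # constant weight of ./ds1
    w2 = step_weight(ds2_schedule, batch_idx)
    total = w1 + w2
    return w1 / total, w2 / total


for batch_idx in (0, 1500, 3000):
    print(batch_idx, shares(batch_idx))
```

Under that assumption, `./ds2` supplies 100/101 (about 99%) of the samples at the start, 10/11 (about 91%) from batch 1500 onward, and none from batch 3000 onward.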

Example (linear schedule):

```yaml
splits:
  train:
    blend:
      - weight: 1
        path: ./ds1
      - weight:
          linear:
            0: 100
            1500: 10
            3000: 0
        path: ./ds2
```

Notes:
- The schedule x-axis is the **batch index** (how many batches have been yielded by the loader on that rank).
- `step` uses the last point with key {math}`\le` the current batch index (knot points are inclusive).
- `linear` linearly interpolates between points and clamps outside the endpoints.
- You may combine scheduled and constant weights across nested metadatasets (*schedule × constant*), but **schedule × schedule** along one path is not supported.
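
The `step` and `linear` rules in the notes above can be sketched as two small functions. This is an illustration of the described semantics only, not energon's implementation; `step_weight` and `linear_weight` are hypothetical names, and the behavior of `step` before the first knot is an assumption (the docs do not specify it).

```python
from bisect import bisect_right


def step_weight(points: dict[int, float], batch_idx: int) -> float:
    """Return the value of the last knot whose key is <= batch_idx."""
    keys = sorted(points)
    pos = bisect_right(keys, batch_idx)
    # Assumption: before the first knot, fall back to the first knot's value.
    return points[keys[max(pos - 1, 0)]]


def linear_weight(points: dict[int, float], batch_idx: int) -> float:
    """Interpolate linearly between knots and clamp outside the endpoints."""
    keys = sorted(points)
    if batch_idx <= keys[0]:
        return float(points[keys[0]])
    if batch_idx >= keys[-1]:
        return float(points[keys[-1]])
    hi = bisect_right(keys, batch_idx)
    x0, x1 = keys[hi - 1], keys[hi]
    y0, y1 = points[x0], points[x1]
    return y0 + (y1 - y0) * (batch_idx - x0) / (x1 - x0)


schedule = {0: 100, 1500: 10, 3000: 0}
print(step_weight(schedule, 1499), step_weight(schedule, 1500))  # 100 10
print(linear_weight(schedule, 750))  # 55.0
```

With the same knot points, `step` holds 100 until batch 1500 and then drops to 10, while `linear` has already decayed to 55 halfway to the first knot.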

To use the metadataset in your loader, simply load it with {py:func}`get_train_dataset <megatron.energon.get_train_dataset>` instead of a normal energon dataset:
```python
from megatron.energon import get_train_dataset
2 changes: 2 additions & 0 deletions docs/source/basic/quickstart.md
@@ -211,6 +211,8 @@ splits:

This assumes that the datasets `coyo` and `coco` exist in subfolders next to the `coyo-coco-dataset.yaml` file. You could also use absolute paths, but that will not work well when using object storage such as S3.

You can also use **scheduled weights** (step/linear) to change sampling proportions over training. See [](metadataset) for details and examples.

To use it in your loader, simply call {py:func}`get_train_dataset <megatron.energon.get_train_dataset>`:
```python
from megatron.energon import get_train_dataset, get_loader, WorkerConfig
6 changes: 6 additions & 0 deletions docs/source/basic/save_restore.md
@@ -152,3 +152,9 @@ See [](../advanced/error_handling) for details.

In each of these scenarios, ensure that the logic for saving and restoring the state is appropriately synchronized across ranks to maintain consistency.
If you encounter torch distributed errors, the torch distributed calls are likely out of sync, or not all ranks are being called correctly. If unsure, debug using the first scenario, saving each rank separately.

```{admonition} Scheduled blend weights
:class: note
If you use scheduled metadataset blend weights (e.g. `weight: {step: ...}` / `weight: {linear: ...}`), they remain deterministic across save/restore.
Schedules are evaluated from the per-rank batch index embedded in the loader state, and the blend RNG state is checkpointed as part of the dataset state.
```
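
The reason this works can be illustrated with a toy model: if the schedule is a pure function of the batch index, then restoring the batch index and the RNG state reproduces the exact same sequence of blend decisions. The sketch below is not energon's loader; `ToyBlend`, `step_weight` and the state layout are hypothetical and only mimic the two pieces of state named above.

```python
import random


def step_weight(points, batch_idx):
    """Value of the last knot whose key is <= batch_idx (hypothetical helper)."""
    return points[max(key for key in points if key <= batch_idx)]


class ToyBlend:
    """Toy stand-in for a blending loader; not energon's implementation."""

    def __init__(self, schedule, seed=0):
        self.schedule = schedule
        self.batch_idx = 0
        self.rng = random.Random(seed)

    def next_source(self) -> str:
        # The weight of ds2 depends only on the current batch index.
        w2 = step_weight(self.schedule, self.batch_idx)
        source = self.rng.choices(["ds1", "ds2"], weights=[1, w2])[0]
        self.batch_idx += 1
        return source

    def save_state(self) -> dict:
        return {"batch_idx": self.batch_idx, "rng": self.rng.getstate()}

    def restore_state(self, state: dict) -> None:
        self.batch_idx = state["batch_idx"]
        self.rng.setstate(state["rng"])


schedule = {0: 100, 5: 10, 10: 0}
blend = ToyBlend(schedule)
for _ in range(4):
    blend.next_source()
state = blend.save_state()
expected = [blend.next_source() for _ in range(10)]

# A fresh object with a different seed continues identically once restored.
resumed = ToyBlend(schedule, seed=123)
resumed.restore_state(state)
actual = [resumed.next_source() for _ in range(10)]
assert actual == expected
```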
3 changes: 2 additions & 1 deletion src/megatron/energon/metadataset/loader_interface.py
@@ -10,6 +10,7 @@
 from megatron.energon.epathlib import EPath
 from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory
 from megatron.energon.flavors.webdataset.structs import DatasetSubset
+from megatron.energon.weights import WeightSpec
 from megatron.energon.worker import WorkerConfig


@@ -27,7 +28,7 @@ class DatasetBlendMode(Enum):
 @edataclass
 class LoadedDataset:
     dataset: BaseCoreDatasetFactory
-    weight: Union[float, int, None] = None
+    weight: Union[float, int, WeightSpec, None] = None
     repetitions: Union[float, int, None] = None
     aux: Optional[Dict[str, FileStore]] = None

9 changes: 7 additions & 2 deletions src/megatron/energon/metadataset/metadataset.py
@@ -15,6 +15,7 @@
     DatasetLoaderInterface,
     LoadedDatasetList,
 )
+from megatron.energon.weights import compose_weights, weight_from_config
 from megatron.energon.worker import WorkerConfig


@@ -128,6 +129,7 @@ def get_datasets(
         **kwargs,
     ) -> LoadedDatasetList:
         sum_weight = sum(dataset.weight for dataset in self.datasets)
+        assert sum_weight > 0, "Sum of blend weights must be > 0"
         datasets = []
         for dataset in self.datasets:
             inner_result = dataset.get_datasets(
@@ -148,11 +150,14 @@ def get_datasets(
             )
             for loaded_dataset in inner_result.datasets:
                 if inner_result.blend_mode == DatasetBlendMode.DATASET_WEIGHT:
-                    assert isinstance(loaded_dataset.weight, float)
+                    assert loaded_dataset.weight is not None
                 else:
                     assert loaded_dataset.weight is None
                     loaded_dataset.weight = 1.0
-                loaded_dataset.weight = loaded_dataset.weight * dataset.weight / sum_weight
+                loaded_dataset.weight = compose_weights(
+                    weight_from_config(loaded_dataset.weight),
+                    float(dataset.weight) / float(sum_weight),
+                )
                 datasets.append(loaded_dataset)
         return LoadedDatasetList(
             blend_mode=DatasetBlendMode.DATASET_WEIGHT,
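
The hunk above replaces the plain multiplication with a call to `compose_weights`, whose implementation is not shown in this diff. As a rough sketch of what composing a constant with a schedule could look like, consistent with the documented rule that *schedule × constant* is allowed and *schedule × schedule* is not, consider the following. The names `Weight` and `compose` are hypothetical, and a weight is modeled here as either a number or a function of the batch index, which is an assumption about the representation.

```python
from typing import Callable, Union

# Assumption: a weight is either a constant or a function of the batch index.
Weight = Union[float, Callable[[int], float]]


def compose(inner: Weight, outer: Weight) -> Weight:
    """Multiply two weights; at most one of them may be a schedule."""
    if callable(inner) and callable(outer):
        raise ValueError("schedule x schedule along one path is not supported")
    if callable(inner):
        return lambda batch_idx: inner(batch_idx) * outer
    if callable(outer):
        return lambda batch_idx: inner * outer(batch_idx)
    return inner * outer


def scheduled(batch_idx: int) -> float:
    """Toy step schedule: 100 before batch 1500, 10 afterwards."""
    return 100.0 if batch_idx < 1500 else 10.0


effective = compose(scheduled, 0.25)  # schedule x constant
print(effective(0), effective(2000))  # 25.0 2.5
print(compose(0.5, 0.25))  # 0.125
```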
39 changes: 34 additions & 5 deletions src/megatron/energon/metadataset/metadataset_v2.py
@@ -25,6 +25,13 @@
     LoadedDatasetList,
 )
 from megatron.energon.metadataset.metadataset import Metadataset
+from megatron.energon.weights import (
+    WeightConfig,
+    WeightLike,
+    compose_weights,
+    make_node_entry_weights,
+    weight_from_config,
+)
 from megatron.energon.worker import WorkerConfig

 # Regex for any URL-like string (any protocol)
@@ -398,7 +405,23 @@ def get_datasets(

 @dataclass
 class BlendWeightMixin:
-    weight: float = 1.0
+    """Mixin for metadataset blend entries that carry a sampling weight.
+
+    `weight` may be a constant number or a schedule mapping:
+    - `weight: 5`
+    - `weight: {step: {0: 100, 1500: 10, 3000: 0}}`
+    - `weight: {linear: {0: 100, 1500: 10, 3000: 0}}`
+
+    Schedules are evaluated at the per-rank batch index and are intended to be deterministic and
+    checkpoint/resume-safe.
+    """
+
+    # Supports either a constant numeric weight or a schedule like:
+    #   weight:
+    #     step:
+    #       0: 100
+    #       1000: 0
+    weight: WeightConfig = 1.0


 @edataclass
@@ -440,9 +463,12 @@ def get_datasets(
         **kwargs,
     ) -> LoadedDatasetList:
         subset = self._get_subset(subset)
-        sum_weight = sum(dataset.weight for dataset in self.blend)
+        # Node-level weights must be normalized at this node to preserve hierarchical blend semantics.
+        entry_weights: List[WeightLike] = make_node_entry_weights(
+            [dataset.weight for dataset in self.blend]
+        )
         datasets = []
-        for dataset in self.blend:
+        for dataset, entry_weight in zip(self.blend, entry_weights):
             inner_result = dataset.get_datasets(
                 training=training,
                 split_part=split_part,
@@ -461,13 +487,16 @@ def get_datasets(
             )
             for loaded_dataset in inner_result.datasets:
                 if inner_result.blend_mode == DatasetBlendMode.DATASET_WEIGHT:
-                    assert isinstance(loaded_dataset.weight, float)
+                    assert loaded_dataset.weight is not None
                 else:
                     assert inner_result.blend_mode == DatasetBlendMode.NONE
                     assert loaded_dataset.weight is None
                     assert loaded_dataset.repetitions is None
                     loaded_dataset.weight = 1.0
-                loaded_dataset.weight = loaded_dataset.weight * dataset.weight / sum_weight
+                loaded_dataset.weight = compose_weights(
+                    weight_from_config(loaded_dataset.weight),
+                    entry_weight,
+                )
                 datasets.append(loaded_dataset)
         return LoadedDatasetList(
             blend_mode=DatasetBlendMode.DATASET_WEIGHT,
2 changes: 1 addition & 1 deletion src/megatron/energon/task_encoder/base.py
@@ -809,7 +809,7 @@ def build_train_datasets(
                 ),
                 worker_config=worker_config,
             ),
-            1.0 if dataset.weight is None else float(dataset.weight),
+            1.0 if dataset.weight is None else dataset.weight,
         )
         for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
     ]