Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 130 additions & 54 deletions modelopt/torch/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,77 +110,148 @@
]


def _normalize_splits(split: str | list[str]) -> list[str]:
    """Return ``split`` as a list of split names.

    A single split name is wrapped in a one-element list; any other iterable
    of names is copied into a new list so the caller's object is never shared.
    """
    if isinstance(split, str):
        return [split]
    return list(split)


def _auto_preprocess_sample(
    sample: dict,
    dataset_name: str,
    tokenizer: "PreTrainedTokenizerBase | None" = None,
) -> str:
    """Auto-detect dataset format and preprocess a single sample based on column conventions.

    Column detection order (first match wins):
    1. ``messages`` / ``conversations`` -> ``tokenizer.apply_chat_template`` (with ``tools`` if present)
    2. ``prompt`` (+ optional ``completion`` / ``response`` / ``output``) -> concatenate
    3. ``text`` -> use as-is
    4. ``input`` (+ optional ``output``) -> concatenate

    A chat column only matches when its value is non-empty; an empty chat column
    falls through to the plain-text columns. For the concatenating formats, empty
    or ``None`` values are skipped and the remaining parts are joined with a newline.

    Args:
        sample: One dataset row, mapping column names to values.
        dataset_name: Dataset name, used only in error messages.
        tokenizer: Tokenizer providing ``apply_chat_template``; required only for
            chat-format samples.

    Returns:
        The sample rendered as plain text. An empty string is returned when the
        detected text columns hold no content.

    Raises:
        ValueError: If the tokenizer is missing/incompatible for chat-format datasets,
            or if no recognized column is found.
    """
    chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None)
    if chat_key is not None:
        if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
            raise ValueError(
                f"Dataset '{dataset_name}' has a '{chat_key}' column but no tokenizer with "
                "apply_chat_template was provided."
            )
        # Forward tool definitions only when the sample actually carries some.
        tools = sample.get("tools")
        extra_kwargs = {"tools": tools} if tools else {}
        return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **extra_kwargs)

    if "prompt" in sample:
        columns = ("prompt", "completion", "response", "output")
    elif "text" in sample:
        # A missing (None) text value becomes "" so the declared return type holds.
        return sample["text"] or ""
    elif "input" in sample:
        columns = ("input", "output")
    else:
        raise ValueError(
            f"Cannot auto-detect format for dataset '{dataset_name}'. "
            f"Found columns: {list(sample.keys())}. "
            "Expected one of: 'messages', 'conversations', 'prompt', 'text', or 'input'."
        )

    # Skipping empty/None parts prevents str.join from failing on a None primary
    # column and avoids a stray leading newline when that column is empty.
    return "\n".join(sample[k] for k in columns if sample.get(k))


def get_dataset_samples(
    dataset_name: str,
    num_samples: int,
    *,
    apply_chat_template: bool = False,
    tokenizer: "PreTrainedTokenizerBase | None" = None,
    split: str | list[str] | None = None,
) -> list[str]:
    """Load a portion of a dataset with the dataset name and a given size.

    Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary
    HuggingFace datasets. Unregistered datasets are auto-detected by column names:
    ``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.

    Args:
        dataset_name: Name or HuggingFace path of the dataset to load.
        num_samples: Number of samples to load from the dataset.
        apply_chat_template: Whether to apply the chat template to the samples
            (if supported by the dataset). For unregistered datasets with a
            ``messages`` column, chat template is always applied regardless of
            this flag.
        tokenizer: Tokenizer to use for applying the chat template to the samples.
            No tokenization is done and plain text is still returned.
        split: Override the split(s) to load. Accepts a single split name or a list.
            If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for
            registered datasets, or ``["train"]`` for unregistered datasets.

    Returns:
        Samples: The list of samples. Samples that preprocess to an empty string
        are dropped, so fewer than ``num_samples`` entries may be returned.

    Raises:
        ValueError: If ``split`` is an empty list, if a chat template is requested
            for a registered chat dataset without a tokenizer, or if the format of
            an unregistered dataset cannot be auto-detected.
    """
    from datasets import load_dataset

    if dataset_name in SUPPORTED_DATASET_CONFIG:
        dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
        config = dataset_config["config"].copy()
        # "split" is always removed from the load kwargs because it is passed
        # explicitly, one split at a time, further below.
        default_splits = config.pop("split", [None])
        splits = _normalize_splits(split) if split is not None else default_splits

        if apply_chat_template:
            if "chat_key" not in dataset_config:
                warn(
                    f"Dataset {dataset_name} does not support chat template."
                    " Chat template will not be applied."
                )
            elif tokenizer is None:
                raise ValueError("Tokenizer is required when applying chat template.")

        def _preprocess(sample: dict) -> str:
            if apply_chat_template and "chat_key" in dataset_config:
                # Forward tool definitions only when the sample carries some.
                tools = sample.get("tools")
                extra_kwargs = {"tools": tools} if tools else {}
                return tokenizer.apply_chat_template(  # type: ignore[union-attr]
                    sample[dataset_config["chat_key"]], tokenize=False, **extra_kwargs
                )
            return dataset_config["preprocess"](sample)

    else:
        warn(
            f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
            "Auto-detecting format from column names."
        )
        config = {"path": dataset_name}
        splits = _normalize_splits(split) if split is not None else ["train"]

        def _preprocess(sample: dict) -> str:
            return _auto_preprocess_sample(sample, dataset_name, tokenizer)

    if not splits:
        # Fail clearly instead of with a ZeroDivisionError in the budget split below.
        raise ValueError("At least one split must be specified.")

    # load_dataset does not support a list of splits while streaming, so load each separately.
    dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

    # Streaming datasets expose no reliable per-split size, so the sample budget is
    # divided evenly across splits, with the remainder assigned to the last split.
    num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
    num_per_split[-1] += num_samples - sum(num_per_split)

    samples: list[str] = []
    for dataset, n in zip(dataset_splits, num_per_split):
        for i, sample in enumerate(dataset):
            if i >= n:
                break
            text = _preprocess(sample)
            if text:  # some datasets (e.g. wikitext) contain empty samples
                samples.append(text)

    return samples

Expand Down Expand Up @@ -208,20 +279,23 @@ def get_dataset_dataloader(
max_sample_length: int = 512,
device: torch.device | None = None,
include_labels: bool = False,
apply_chat_template: bool = False,
) -> DataLoader:
"""Get a dataloader with the dataset name and toknizer of the target model.
"""Get a dataloader with the dataset name and tokenizer of the target model.

Args:
dataset_name: Name of the dataset to load.
tokenizer: Instancne of Hugginface tokenizer.
tokenizer: Instance of HuggingFace tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
max_sample_length: Maximum length of a sample.
device: Target device for the returned dataloader.
include_labels: Whether to include labels in the dataloader.
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset).

Returns:
A instance of dataloader.
An instance of dataloader.
"""
assert tokenizer is not None, "Please provide a tokenizer."
# batch_encode_plus will modify the tokenizer in place, so we need to clone it.
Expand All @@ -244,7 +318,9 @@ def get_dataset_dataloader(

all_samples = []
for ds_name, num_sample in zip(dataset_name, num_samples):
samples = get_dataset_samples(ds_name, num_sample)
samples = get_dataset_samples(
ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
)
all_samples.extend(samples)

batch_encoded = tokenizer.batch_encode_plus(
Expand Down