From 03eeeb41e93bc44823637be6331f107cc3f690e9 Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 21 Feb 2026 18:56:20 -0800 Subject: [PATCH] squash all 4 commits in 1 Signed-off-by: michaelfeil <63565275+michaelfeil@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 184 ++++++++++++++++++-------- 1 file changed, 130 insertions(+), 54 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 16bff49c2..0da8ad726 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -110,77 +110,148 @@ ] +def _normalize_splits(split: str | list[str]) -> list[str]: + """Ensure split is always a list.""" + return [split] if isinstance(split, str) else list(split) + + +def _auto_preprocess_sample( + sample: dict, + dataset_name: str, + tokenizer: "PreTrainedTokenizerBase | None" = None, +) -> str: + """Auto-detect dataset format and preprocess a single sample based on column conventions. + + Column detection order (first match wins): + 1. ``messages`` / ``conversations`` -> ``tokenizer.apply_chat_template`` (with ``tools`` if present) + 2. ``prompt`` (+ optional ``completion`` / ``response`` / ``output``) -> concatenate + 3. ``text`` -> use as-is + 4. ``input`` (+ optional ``output``) -> concatenate + + Raises: + ValueError: If the tokenizer is missing/incompatible for chat-format datasets, + or if no recognized column is found. + """ + chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None) + if chat_key is not None: + if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"): + raise ValueError( + f"Dataset '{dataset_name}' has a '{chat_key}' column but no tokenizer with " + "apply_chat_template was provided." + ) + kwargs: dict[str, Any] = {} + tools = sample.get("tools") + if tools: + kwargs["tools"] = tools + return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **kwargs) + + if "prompt" in sample: + parts = [sample["prompt"]] + parts.extend(sample[k] for k in ("completion", "response", "output") if sample.get(k)) + return "\n".join(parts) + + if "text" in sample: + return sample["text"] + + if "input" in sample: + parts = [sample["input"]] + if sample.get("output"): + parts.append(sample["output"]) + return "\n".join(parts) + + raise ValueError( + f"Cannot auto-detect format for dataset '{dataset_name}'. " + f"Found columns: {list(sample.keys())}. " + "Expected one of: 'messages', 'conversations', 'prompt', 'text', or 'input'." + ) + + def get_dataset_samples( dataset_name: str, num_samples: int, *, apply_chat_template: bool = False, tokenizer: "PreTrainedTokenizerBase | None" = None, + split: str | list[str] | None = None, ) -> list[str]: - """Load a portion of train dataset with the dataset name and a given size. + """Load a portion of a dataset with the dataset name and a given size. + + Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary + HuggingFace datasets. Unregistered datasets are auto-detected by column names: + ``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``. Args: - dataset_name: Name of the dataset to load. + dataset_name: Name or HuggingFace path of the dataset to load. num_samples: Number of samples to load from the dataset. - apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). + apply_chat_template: Whether to apply the chat template to the samples + (if supported by the dataset). For unregistered datasets with a + ``messages`` column, chat template is always applied regardless of + this flag. tokenizer: Tokenizer to use for applying the chat template to the samples. No tokenization is done and plain text is still returned. + split: Override the split(s) to load. Accepts a single split name or a list. + If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for + registered datasets, or ``["train"]`` for unregistered datasets. Returns: Samples: The list of samples. """ - # Load the dataset - if dataset_name not in SUPPORTED_DATASET_CONFIG: - raise NotImplementedError( - f"dataset {dataset_name} is not supported. Please use one of the following:" - f" {get_supported_datasets()}." - ) - from datasets import load_dataset - dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name] - if apply_chat_template: - if "chat_key" not in dataset_config: - warn( - f"Dataset {dataset_name} does not support chat template. Chat template will not be applied." - ) - elif tokenizer is None: - raise ValueError("Tokenizer is required when applying chat template.") - print(f"Applying chat template to dataset {dataset_name}") - - # It's unfortunate that the load_dataset function does not support split a list while streaming. - # So we need to load the dataset for each split. - config = dataset_config["config"].copy() - splits = config.pop("split", [None]) - dataset_splits = [ - load_dataset( - streaming=True, - **config, - split=split, - ) - for split in splits - ] - - # Split the samples evenly across the splits - # For streaming datasets, there is no reliable way to get the number of samples in each split - # other than loading the entire dataset. So, we just use the same number of samples for each split. - num_samples_splits = [num_samples // len(dataset_splits) for _ in dataset_splits] - num_samples_splits[-1] += num_samples - sum(num_samples_splits) - samples = [] - for dataset, num_samples_split in zip(dataset_splits, num_samples_splits): - for i, sample in enumerate(dataset): - if i >= num_samples_split: - break + is_registered = dataset_name in SUPPORTED_DATASET_CONFIG + + if is_registered: + dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name] + config = dataset_config["config"].copy() + splits = _normalize_splits(split) if split is not None else config.pop("split", [None]) + if split is not None: + config.pop("split", None) + + if apply_chat_template: + if "chat_key" not in dataset_config: + warn( + f"Dataset {dataset_name} does not support chat template." + " Chat template will not be applied." + ) + elif tokenizer is None: + raise ValueError("Tokenizer is required when applying chat template.") - # Apply preprocess function to the sample + def _preprocess(sample: dict) -> str: if apply_chat_template and "chat_key" in dataset_config: - sample = tokenizer.apply_chat_template( # type: ignore[union-attr] - sample[dataset_config["chat_key"]], tokenize=False + kwargs: dict[str, Any] = {} + tools = sample.get("tools") + if tools: + kwargs["tools"] = tools + return tokenizer.apply_chat_template( # type: ignore[union-attr] + sample[dataset_config["chat_key"]], tokenize=False, **kwargs ) - else: - sample = dataset_config["preprocess"](sample) - if sample != "": # wikitext has some empty samples - samples.append(sample) + return dataset_config["preprocess"](sample) + + else: + warn( + f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. " + "Auto-detecting format from column names." + ) + config = {"path": dataset_name} + splits = _normalize_splits(split) if split is not None else ["train"] + + def _preprocess(sample: dict) -> str: + return _auto_preprocess_sample(sample, dataset_name, tokenizer) + + # load_dataset does not support a list of splits while streaming, so load each separately. + dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits] + + num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits) + num_per_split[-1] += num_samples - sum(num_per_split) + + samples: list[str] = [] + for dataset, n in zip(dataset_splits, num_per_split): + for i, sample in enumerate(dataset): + if i >= n: + break + text = _preprocess(sample) + if text: + samples.append(text) return samples @@ -208,20 +279,23 @@ def get_dataset_dataloader( max_sample_length: int = 512, device: torch.device | None = None, include_labels: bool = False, + apply_chat_template: bool = False, ) -> DataLoader: - """Get a dataloader with the dataset name and toknizer of the target model. + """Get a dataloader with the dataset name and tokenizer of the target model. Args: dataset_name: Name of the dataset to load. - tokenizer: Instancne of Hugginface tokenizer. + tokenizer: Instance of HuggingFace tokenizer. batch_size: Batch size of the returned dataloader. num_samples: Number of samples from the dataset. max_sample_length: Maximum length of a sample. device: Target device for the returned dataloader. include_labels: Whether to include labels in the dataloader. + apply_chat_template: Whether to apply the chat template to the samples + (if supported by the dataset). Returns: - A instance of dataloader. + An instance of dataloader. """ assert tokenizer is not None, "Please provide a tokenizer." # batch_encode_plus will modify the tokenizer in place, so we need to clone it. @@ -244,7 +318,9 @@ def get_dataset_dataloader( all_samples = [] for ds_name, num_sample in zip(dataset_name, num_samples): - samples = get_dataset_samples(ds_name, num_sample) + samples = get_dataset_samples( + ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer + ) all_samples.extend(samples) batch_encoded = tokenizer.batch_encode_plus(