diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 385cbcb65..b43e2c540 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -43,7 +43,7 @@ Once inside the container, you need to login with your HuggingFace token to down Note that the default dataset for pruning and quantization is [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), which is gated. ```bash -huggingface-cli login --token +hf auth login --token ``` ## Pruning @@ -97,23 +97,40 @@ The [distill.py](distill.py) script loads student and teacher models from Huggin ### Data Preparation The distillation script expects pre-tokenized data in Megatron's binary format (`.bin` / `.idx` files). -You can tokenize your JSONL dataset using the following function: - -```python -from modelopt.torch.utils.plugins import megatron_preprocess_data - -megatron_preprocess_data( - input_path="/path/to/your/data.jsonl", - output_dir="/path/to/tokenized/data", - tokenizer_name_or_path="Qwen/Qwen3-0.6B", - json_keys=["text"], # change to your JSON key if needed - workers=32, - log_interval=100000, - max_sequence_length=256000, # To avoid rare OOM errors if text is too long -) + +You can tokenize your JSONL datasets using the following command: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --jsonl_paths /path/to/data1.jsonl /path/to/data2.jsonl ... \ + --json_keys text \ + --tokenizer Qwen/Qwen3-0.6B \ + --output_dir /path/to/tokenized/data/qwen3 \ + --workers 32 \ + --max_sequence_length 256_000 +``` + +Instead of `--jsonl_paths`, you can also pass a directory path to the `--input_dir` argument to tokenize all JSONL files in the directory. +We are setting a maximum sequence length of 256k to avoid rare OOM errors in tokenization if text is too long. + +If you want to download and tokenize a dataset from Hugging Face Hub directly, you can use the following command: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --hf_dataset nvidia/Nemotron-Pretraining-SFT-v1 \ + --hf_name Nemotron-SFT-General \ + --hf_split train \ + --hf_max_samples_per_split 10_000_000 \ + --json_keys text \ + --tokenizer Qwen/Qwen3-0.6B \ + --output_dir /path/to/tokenized/data/qwen3 \ + --workers 32 \ + --max_sequence_length 256_000 ``` -If you have multiple JSONL files, you can tokenize them one by one and pass all the paths to the `--data_paths` argument. +If you skip `--hf_name`, it will download and tokenize all subsets for the dataset. +If you skip `--hf_split`, it will download and tokenize all splits for the subset. +If you skip `--hf_max_samples_per_split`, it will download and tokenize all samples for the split. ### Distillation with Real Data @@ -124,7 +141,7 @@ torchrun --nnodes 1 --nproc_per_node 8 distill.py \ --tp_size 8 \ --teacher_hf_path Qwen/Qwen3-8B \ --student_hf_path Qwen/Qwen3-4B \ - --data_paths 1.0 /path/to/tokenized/data \ + --data_paths 1.0 /path/to/tokenized/data/qwen3 \ --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \ --seq_length 8192 \ --mbs 1 \ diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index 31f1cfc71..8fc5cff6f 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -163,7 +163,7 @@ def _build_model_provider(hf_path): lr_warmup_iters=args.lr_warmup_iters, max_lr=args.lr, min_lr=args.min_lr, - adam_beta2=0.98, + adam_beta2=0.95, ) # Build dataset config @@ -227,7 +227,7 @@ def _build_model_provider(hf_path): save_interval=args.eval_interval, save=checkpoint_dir, load=checkpoint_dir, # Resume from this directory (if exists) - most_recent_k=3, # Keeps 3 most recent checkpoints (not metric-based) + most_recent_k=5, # Keeps 5 most recent checkpoints (not metric-based) ckpt_format="torch_dist", async_save=True, fully_parallel_save=True, @@ -238,7 +238,9 @@ def _build_model_provider(hf_path): print_rank_0("\nStarting distillation...") distill(config) - print_rank_0(f"\nDistillation done! Saved checkpoint to {checkpoint_dir}\n") + print_rank_0( + f"\nDistillation done! Saved checkpoint to {checkpoint_dir} in megatron distributed checkpoint format.\n" + ) if __name__ == "__main__": diff --git a/examples/nemo_run/common/process_climbmix.py b/examples/nemo_run/common/process_climbmix.py index 18fd35f2d..a6f91cc11 100644 --- a/examples/nemo_run/common/process_climbmix.py +++ b/examples/nemo_run/common/process_climbmix.py @@ -67,7 +67,7 @@ def get_args(): print("Tokenizing ClimbMix dataset...") input_paths = [raw_dir / name for name in subset_filenames] megatron_preprocess_data( - input_paths, + jsonl_paths=input_paths, output_dir=proc_dir, tokenizer_name_or_path=args.tokenizer, append_eod=True, diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 3d476ec8f..6470776e7 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -317,6 +317,7 @@ def run_search(self) -> None: # Prune homogeneously self._prune(export_config, prune_depth=True) + # TODO: Rename to hybrid_layer_pattern after https://github.com/NVIDIA/Megatron-LM/pull/3377 # Update hybrid_override_pattern if pruning is done on a hybrid model if isinstance(self.model, MambaModel): print_rank_0(f"Original hybrid_override_pattern: {self.model.hybrid_override_pattern}") diff --git a/modelopt/torch/utils/plugins/megatron_preprocess_data.py b/modelopt/torch/utils/plugins/megatron_preprocess_data.py index 92ea4bd51..ea1dd0697 100644 --- a/modelopt/torch/utils/plugins/megatron_preprocess_data.py +++ b/modelopt/torch/utils/plugins/megatron_preprocess_data.py @@ -17,28 +17,52 @@ """Processing large data to tokenize for pretraining. -Usage: +Usage to tokenize one or more JSONL files: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --jsonl_paths path/to/input/data1.jsonl path/to/input/data2.jsonl ... \ + --json_keys text \ + --output_dir /path/to/tokenized/Qwen3/ \ + --tokenizer Qwen/Qwen3-0.6B +``` -```python -from modelopt.torch.utils.plugins import megatron_preprocess_data +Usage to tokenize all JSONL files in a directory: -megatron_preprocess_data( - input_path="path/to/input/data", - output_dir="path/to/output/dir", - tokenizer_name_or_path="hf_model_name", - json_keys=["name of json key(s) to tokenize"], -) +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --input_dir /path/to/input/data/ \ + --json_keys text \ + --output_dir /path/to/tokenized/Qwen3/ \ + --tokenizer Qwen/Qwen3-0.6B ``` + +Usage to download and tokenize a dataset from Hugging Face Hub: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --hf_dataset nvidia/Nemotron-Pretraining-Dataset-sample \ + --hf_name Nemotron-SFT-Code \ + --hf_split train \ + --json_keys text \ + --tokenizer Qwen/Qwen3-0.6B \ + --output_dir /path/to/tokenized/Qwen3/ +``` + +NOTE: If you skip --hf_name, it will download and tokenize all subsets for the dataset. +If you skip --hf_split, it will download and tokenize all splits for the subset. """ import argparse import json import multiprocessing -import sys +import os from pathlib import Path +from warnings import warn import requests from datasets import load_dataset +from huggingface_hub.utils import build_hf_headers from megatron.core.datasets import indexed_dataset from transformers import AutoTokenizer @@ -108,11 +132,12 @@ def __init__(self, vocab_size: int, json_keys: list[str], log_interval: int, wor self.log_interval = log_interval self.workers = workers - def _print_processing_stats(self, count: int, total_doc_len: int, total_enc_len: int): - if count % self.log_interval == 0: + def _print_processing_stats( + self, count: int, total_doc_len: int, total_enc_len: int, *, force_print: bool = False + ): + if count % self.log_interval == 0 or force_print: print( - f"Processed {num2hrb(count)} docs = {num2hrb(total_doc_len)} chars = {num2hrb(total_enc_len)} tokens", - file=sys.stderr, + f"\tProcessed {num2hrb(count)} docs = {num2hrb(total_doc_len)} chars = {num2hrb(total_enc_len)} tokens" ) def process_json_file( @@ -120,7 +145,7 @@ def process_json_file( ): output_prefix = Path(output_dir) / Path(input_file_name).stem - print("Opening", input_file_name) + print(f"\nOpening {input_file_name}") fin = open(input_file_name, encoding="utf-8") pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer) @@ -142,7 +167,7 @@ def process_json_file( ) if not builders: - print(f"Output files corresponding to {input_file_name} already exist, skipping") + print(f"\t[SKIP] Output files corresponding to {input_file_name} already exist") return 0 total_doc_len, total_enc_len, final_enc_len = 0, 0, 0 @@ -153,6 +178,7 @@ def process_json_file( for key in doc: builders[key].add_document(doc[key], sentence_lens[key]) self._print_processing_stats(i, total_doc_len, total_enc_len) + self._print_processing_stats(i, total_doc_len, total_enc_len, force_print=True) fin.close() for key in builders: @@ -161,8 +187,92 @@ def process_json_file( return final_enc_len +def _download_hf_dataset( + dataset: str, + output_dir: str | Path, + json_keys: list[str], + name: str | None = None, + split: str | None = "train", + max_samples_per_split: int | None = None, +) -> list[str]: + """Download a Hugging Face dataset and save as JSONL files. + + Returns: + List of paths to downloaded JSONL files. + """ + print(f"Downloading dataset {dataset} from Hugging Face") + jsonl_paths: list[str] = [] + + try: + response = requests.get( + f"https://datasets-server.huggingface.co/splits?dataset={dataset}", + headers=build_hf_headers(), + timeout=10, + ) + response.raise_for_status() + except requests.RequestException as e: + raise RuntimeError(f"Failed to fetch dataset splits for {dataset}: {e}") from e + + response_json = response.json() + print(f"\nFound {len(response_json['splits'])} total splits for {dataset}:") + for entry in response_json["splits"]: + print(f"\t{entry}") + + splits_to_process = [] + for entry in response_json["splits"]: + if name is not None and name != entry.get("config", None): + continue + if split is not None and split != entry["split"]: + continue + splits_to_process.append(entry) + + print(f"\nFound {len(splits_to_process)} splits to process:") + for entry in splits_to_process: + print(f"\t{entry}") + + for entry in splits_to_process: + skip_processing = False + path = entry["dataset"] + name = entry.get("config", None) + split = entry["split"] + if max_samples_per_split is not None: + split = f"{split}[:{max_samples_per_split}]" + jsonl_file_path = f"{output_dir}/raw/{path.replace('/', '--')}_{name}_{split}.jsonl" + + print(f"\nLoading HF dataset {path=}, {name=}, {split=}") + if os.path.exists(jsonl_file_path): + jsonl_paths.append(jsonl_file_path) + print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists") + continue + ds = load_dataset(path=path, name=name, split=split) + + for key in json_keys: + if key not in ds.features: + warn(f"[SKIP] {key=} not found in {ds.features=}") + skip_processing = True + break + + if skip_processing: + continue + + print(f"Saving raw dataset to {jsonl_file_path}") + ds.to_json(jsonl_file_path) + jsonl_paths.append(jsonl_file_path) + + print(f"\n\nTokenizing JSONL paths: {jsonl_paths}\n") + return jsonl_paths + + def megatron_preprocess_data( - input_path: str | Path | list[str] | list[Path], + *, + input_dir: str | Path | None = None, + jsonl_paths: str | Path | list[str] | list[Path] | None = None, + # Hugging Face Hub dataset arguments + hf_dataset: str | None = None, + hf_name: str | None = None, + hf_split: str | None = "train", + hf_max_samples_per_split: int | None = None, + # Other arguments output_dir: str | Path, tokenizer_name_or_path: str, json_keys: list[str] = ["text"], @@ -173,25 +283,48 @@ def megatron_preprocess_data( ): """Process large data for pretraining. + Exactly one of ``input_dir``, ``jsonl_paths``, or ``hf_dataset`` must be provided. + Args: - input_path (str | Path | list): Path to file or directory - containing input JSONL files, or list of paths to JSONL files - output_dir (str | Path): Path to directory to save binary output files - tokenizer_name_or_path (str): Name or path of the Hugging Face tokenizer to use - json_keys (list, optional): List of keys to extract from json. Defaults to ["text"] - append_eod (bool, optional): Append an token to the end of a document. Defaults to False - max_sequence_length (int, optional): Maximum tokenized sequence length. Defaults to None - workers (int, optional): Number of worker processes to launch. Defaults to 1 - log_interval (int, optional): Interval between progress updates. Defaults to 1000 + input_dir (str | Path, optional): Directory containing JSONL files to tokenize. + jsonl_paths (str | Path | list, optional): One or more paths to JSONL files. + hf_dataset (str, optional): Hugging Face Hub dataset name or path to download and tokenize. + hf_name (str, optional): Hugging Face Hub dataset subset name. Downloads all subsets if None. + hf_split (str, optional): Hugging Face Hub dataset split. Defaults to "train". + hf_max_samples_per_split (int, optional): Maximum number of samples to download per split from Hugging Face Hub. + Skip to download all samples. + output_dir (str | Path): Path to directory to save binary output files. + tokenizer_name_or_path (str): Name or path of the Hugging Face tokenizer to use. + json_keys (list, optional): List of keys to extract from json. Defaults to ["text"]. + append_eod (bool, optional): Append an token to the end of a document. Defaults to False. + max_sequence_length (int, optional): Maximum tokenized sequence length. Defaults to None. + workers (int, optional): Number of worker processes to launch. Defaults to 1. + log_interval (int, optional): Interval between progress updates. Defaults to 100000. """ - if isinstance(input_path, list): - file_names = input_path - elif Path(input_path).is_file(): - file_names = [input_path] - else: - file_names = sorted(Path(input_path).glob("*.jsonl")) + num_sources = sum(x is not None for x in (input_dir, jsonl_paths, hf_dataset)) + if num_sources != 1: + raise ValueError( + "Exactly one of `input_dir`, `jsonl_paths`, or `hf_dataset` must be provided." + ) + + if hf_dataset is not None: + jsonl_paths = _download_hf_dataset( + hf_dataset, + output_dir, + json_keys, + name=hf_name, + split=hf_split, + max_samples_per_split=hf_max_samples_per_split, + ) + + if input_dir is not None: + file_names = sorted(Path(input_dir).glob("*.jsonl")) if not file_names: - raise ValueError(f"No JSONL files found in input path: {input_path}") + raise ValueError(f"No JSONL files found in input directory: {input_dir}") + elif isinstance(jsonl_paths, (str, Path)): + file_names = [jsonl_paths] # type: ignore[list-item] + else: + file_names = list(jsonl_paths) # type: ignore[arg-type] Path(output_dir).mkdir(exist_ok=True) vocab_size = AutoTokenizer.from_pretrained(tokenizer_name_or_path).vocab_size @@ -204,32 +337,43 @@ def megatron_preprocess_data( num_tokens = partition.process_json_file(name, output_dir, encoder) final_enc_len += num_tokens - print(f">>> Total number of tokens: {num2hrb(final_enc_len)}") + print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}") def main(): - """Sample main function to process large data for pretraining. - - Example usage: - - >>> python megatron_preprocess_data.py \ - --dataset "nvidia/Nemotron-Pretraining-Dataset-sample" \ - --tokenizer "meta-llama/Llama-3.2-1B-Instruct" \ - --output_dir "./processed_data" - """ + """Sample main function to process large data for pretraining.""" parser = argparse.ArgumentParser(prog="megatron_preprocess_data") - parser.add_argument("--input_path", type=str, default=None, help="Input path.") + # Dataset arguments (pre-downloaded .jsonl files or download from Hugging Face Hub) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--input_dir", type=str, help="Directory containing JSONL files") + group.add_argument( + "--jsonl_paths", nargs="+", type=str, help="One or more paths to JSONL files" + ) + group.add_argument( + "--hf_dataset", + type=str, + help="Hugging Face Hub dataset path to download and tokenize", + ) + parser.add_argument( + "--hf_name", + type=str, + default=None, + help="Hugging Face Hub dataset subset name. Skip to download and tokenize all subsets for the dataset.", + ) parser.add_argument( - "--dataset", + "--hf_split", type=str, - default="nvidia/Nemotron-Pretraining-Dataset-sample", - help="Hugging Face Hub dataset name or path", + default="train", + help="Hugging Face Hub dataset split. Skip to download and tokenize all splits for the subset.", ) - parser.add_argument("--subset", type=str, default=None, help="Hugging Face Hub dataset subset") - parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split") parser.add_argument( - "--output_dir", type=str, default="./processed_data", help="Output directory" + "--hf_max_samples_per_split", + type=int, + default=None, + help="Maximum number of samples to download per split from Hugging Face Hub. Skip to download all samples.", ) + # Other arguments + parser.add_argument("--output_dir", type=str, required=True, help="Output directory") parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path") parser.add_argument("--json_keys", nargs="+", default=["text"], help="JSON keys to tokenize") parser.add_argument("--append_eod", action="store_true", help="Append token") @@ -237,51 +381,21 @@ def main(): "--max_sequence_length", type=int, default=None, help="Maximum sequence length" ) parser.add_argument("--workers", type=int, default=8, help="Number of worker processes") - parser.add_argument("--log_interval", type=int, default=1000, help="Log interval") + parser.add_argument("--log_interval", type=int, default=100000, help="Log interval") args = parser.parse_args() - if args.input_path is None: - args.input_path = [] - - try: - response = requests.get( - f"https://datasets-server.huggingface.co/splits?dataset={args.dataset}", - timeout=10, - ) - response.raise_for_status() - except requests.RequestException as e: - print(f"Failed to fetch dataset splits for {args.dataset}: {e}") - return - - for entry in response.json()["splits"]: - skip_processing = False - name = entry["dataset"] - subset = entry.get("config", None) - split = entry["split"] - - if args.subset is not None and args.subset != subset: - skip_processing = True - if args.split is not None and args.split != split: - skip_processing = True - - print(f"Loading dataset {name} with subset {subset} and split {split}") - dataset = load_dataset(name, subset, split=split) - - for key in args.json_keys: - if key not in dataset.features: - print(f"Key {key} not found in dataset features. Skipping...") - skip_processing = True - break - - if skip_processing: - continue - - json_file_path = args.output_dir + "/" + name + "_" + subset + "_" + split + ".jsonl" - dataset.to_json(json_file_path) - args.input_path += [json_file_path] + print("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print(f"{k:<35} {v}") + print("===================================================\n") megatron_preprocess_data( - input_path=args.input_path, + input_dir=args.input_dir, + jsonl_paths=args.jsonl_paths, + hf_dataset=args.hf_dataset, + hf_name=args.hf_name, + hf_split=args.hf_split, + hf_max_samples_per_split=args.hf_max_samples_per_split, output_dir=args.output_dir, tokenizer_name_or_path=args.tokenizer, json_keys=args.json_keys, diff --git a/tests/unit/torch/utils/test_megatron_preprocess_data.py b/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py similarity index 52% rename from tests/unit/torch/utils/test_megatron_preprocess_data.py rename to tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py index dbdd8e308..de6bc7181 100644 --- a/tests/unit/torch/utils/test_megatron_preprocess_data.py +++ b/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py @@ -15,18 +15,9 @@ import json import os -import platform from pathlib import Path -import pytest -from _test_utils.import_helper import skip_if_no_megatron - -if platform.system() == "Windows": - pytest.skip("Skipping on Windows", allow_module_level=True) - -skip_if_no_megatron() -datasets = pytest.importorskip("datasets") -_ = pytest.importorskip("transformers") +from datasets import load_dataset from modelopt.torch.utils.plugins.megatron_preprocess_data import megatron_preprocess_data @@ -40,49 +31,46 @@ def download_and_prepare_minipile_dataset(output_dir: Path) -> Path: Returns: Path to the created JSONL file """ - # Download the dataset - dataset = datasets.load_dataset("nanotron/minipile_100_samples", split="train") + dataset = load_dataset("nanotron/minipile_100_samples", split="train") - # Convert to JSONL format jsonl_file = output_dir / "minipile_100_samples.jsonl" with open(jsonl_file, "w", encoding="utf-8") as f: for item in dataset: - # Extract the text field and write as JSONL json_obj = {"text": item["text"]} f.write(json.dumps(json_obj) + "\n") return jsonl_file -def test_megatron_preprocess_data_with_minipile_dataset(tmp_path): - """Test megatron_preprocess_data function with nanotron/minipile_100_samples dataset. +def test_megatron_preprocess_data_with_minipile_jsonl(tmp_path): + """Test megatron_preprocess_data with nanotron/minipile_100_samples dataset. This test: 1. Downloads the HuggingFace dataset "nanotron/minipile_100_samples" 2. Converts it to JSONL format - 3. Passes it to megatron_preprocess_data + 3. Calls megatron_preprocess_data with jsonl_paths 4. Verifies that output files are created """ - # Download and prepare the dataset input_jsonl = download_and_prepare_minipile_dataset(tmp_path) - # Verify the input file was created and has content assert input_jsonl.exists(), "Input JSONL file should exist" assert input_jsonl.stat().st_size > 0, "Input JSONL file should not be empty" - # Test the megatron_preprocess_data function + with open(input_jsonl, encoding="utf-8") as f: + first_line = f.readline().strip() + first_item = json.loads(first_line) + assert "text" in first_item, "Each JSONL item should have a 'text' field" + assert isinstance(first_item["text"], str), "Text field should be a string" + megatron_preprocess_data( - input_path=input_jsonl, + jsonl_paths=input_jsonl, output_dir=tmp_path, - tokenizer_name_or_path="gpt2", # Use a small, common tokenizer + tokenizer_name_or_path="gpt2", json_keys=["text"], - append_eod=False, workers=1, - log_interval=10, ) - # Verify that output files were created output_prefix = tmp_path / "minipile_100_samples" expected_bin_file = f"{output_prefix}_text_document.bin" expected_idx_file = f"{output_prefix}_text_document.idx" @@ -94,55 +82,31 @@ def test_megatron_preprocess_data_with_minipile_dataset(tmp_path): f"Expected index file {expected_idx_file} should exist" ) - # Verify the files have content (non-zero size) assert os.path.getsize(expected_bin_file) > 0, "Binary file should not be empty" assert os.path.getsize(expected_idx_file) > 0, "Index file should not be empty" - # Optional: Verify the input JSONL file structure - with open(input_jsonl, encoding="utf-8") as f: - first_line = f.readline().strip() - first_item = json.loads(first_line) - assert "text" in first_item, "Each JSONL item should have a 'text' field" - assert isinstance(first_item["text"], str), "Text field should be a string" - -def test_megatron_preprocess_data_with_custom_parameters(tmp_path): - """Test megatron_preprocess_data with different parameters.""" - # Create a minimal test dataset - input_jsonl = tmp_path / "test_data.jsonl" +def test_megatron_preprocess_data_with_hf_dataset(tmp_path): + """Test megatron_preprocess_data with dataset download, --append_eod and --max_sequence_length. - # Create some test data - test_data = [ - {"text": "This is a test sentence for preprocessing."}, - {"text": "Another test sentence with different content."}, - {"text": "A third sentence to make sure the function works correctly."}, - ] - - with open(input_jsonl, "w", encoding="utf-8") as f: - f.writelines(json.dumps(item) + "\n" for item in test_data) - - # Test with different parameters + Downloads nanotron/minipile_100_samples train split from Hugging Face and tokenizes it. + """ megatron_preprocess_data( - input_path=input_jsonl, + hf_dataset="nanotron/minipile_100_samples", + hf_split="train", output_dir=tmp_path, tokenizer_name_or_path="gpt2", json_keys=["text"], - append_eod=True, # Test with end-of-document token - max_sequence_length=5, # Test with sequence length limit - workers=1, - log_interval=1, + append_eod=True, + max_sequence_length=512, + workers=4, ) - # Verify output files exist - output_prefix = tmp_path / "test_data" - expected_bin_file = f"{output_prefix}_text_document.bin" - expected_idx_file = f"{output_prefix}_text_document.idx" + bin_files = sorted(tmp_path.glob("*.bin")) + idx_files = sorted(tmp_path.glob("*.idx")) - assert os.path.exists(expected_bin_file), ( - f"Expected binary file {expected_bin_file} should exist" - ) - assert os.path.exists(expected_idx_file), ( - f"Expected index file {expected_idx_file} should exist" - ) - assert os.path.getsize(expected_bin_file) > 0, "Binary file should not be empty" - assert os.path.getsize(expected_idx_file) > 0, "Index file should not be empty" + assert len(bin_files) > 0, f"Expected .bin files in {tmp_path}, found none" + assert len(idx_files) > 0, f"Expected .idx files in {tmp_path}, found none" + + for f in bin_files + idx_files: + assert f.stat().st_size > 0, f"{f.name} should not be empty"