Conversation
```python
import os
import pathlib
from dataclasses import dataclass
from torch.profiler import ProfilerActivity
```
Pull request overview
Adds MLPerf-oriented training support for Llama 3.1 8B FP8 in Primus/Megatron, including early-stop/time-to-train logging, reproducible validation sampling, optional TE fused SwiGLU, and runnable MLPerf example artifacts.
Changes:
- Add MLPerf-style early stopping + time-to-train logging driven by an eval-loss target.
- Patch Megatron validation sampling to be fixed/reproducible for MLPerf evaluation.
- Introduce MLPerf example scripts/configs (train entrypoint, run script, profiler handler, platform config, README).
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| primus/modules/trainer/megatron/trainer.py | Adds training start timestamp, early-stop on target eval loss, and time-to-train logging. |
| primus/modules/trainer/megatron/pre_trainer.py | Makes batch tuple return order explicit instead of relying on dict ordering. |
| primus/backends/megatron/training/evaluator.py | Stores eval “lm loss” onto args for downstream early-stop logic. |
| primus/backends/megatron/patches/validation_data_sampling_patches.py | Adds MLPerf patch hooks to fix validation sample counts and loader sampling behavior. |
| primus/backends/megatron/patches/te_patches/fused_bias_swiglu_patches.py | Adds optional TE fused swiglu/dswiglu patching via env flag. |
| examples/mlperf/src/train.py | Adds an MLPerf training entrypoint script. |
| examples/mlperf/src/prof_handler.py | Adds a torch profiler output/handler utility. |
| examples/mlperf/run_and_time.sh | Adds a run-and-time wrapper script for MLPerf timing runs. |
| examples/mlperf/README.md | Adds setup/run instructions for the MLPerf example. |
| examples/mlperf/configs/MI355X/llama3.1_8B-pretrain-FP8.yaml | Adds an MLPerf-oriented MI355X FP8 training config. |
| examples/mlperf/config_MI355X_1x8x1.sh | Adds MI355X 1x8x1 environment configuration for MLPerf runs. |
```python
enable_forward_pre_hook(model)
pre_hook_enabled = True

target_eval_loss = float(os.environ.get("TARGET_EVAL_LOSS", "0"))
```
Early-stop threshold is read from TARGET_EVAL_LOSS, but the MLPerf example scripts/configs in this PR set MLLOG_TARGET_EVAL_LOSS. As-is, early stopping won't trigger unless users also export TARGET_EVAL_LOSS. Consider reading MLLOG_TARGET_EVAL_LOSS (or supporting both names with a clear precedence).
Suggested change:
```diff
-target_eval_loss = float(os.environ.get("TARGET_EVAL_LOSS", "0"))
+target_eval_loss_env = os.environ.get("TARGET_EVAL_LOSS")
+if target_eval_loss_env is None:
+    target_eval_loss_env = os.environ.get("MLLOG_TARGET_EVAL_LOSS", "0")
+target_eval_loss = float(target_eval_loss_env)
```
```python
if target_eval_loss > 0 and hasattr(args, "_eval_val_loss"):
    if args._eval_val_loss <= target_eval_loss:
        run_duration = time.time() - train_start_time
        log_rank_0(
            f"[EarlyStop] Reached target! Stopping training. "
            f"eval_loss: {args._eval_val_loss:.6f} (target: {target_eval_loss}) | "
```
The early-stop decision is based on args._eval_val_loss, but that attribute is only set on the pipeline stage that computes loss (see primus/backends/megatron/training/evaluator.py). With pipeline parallelism > 1, only last-stage ranks will update args.train_iters/args.do_valid, while other ranks keep training, which can deadlock. Please synchronize the stop condition across all ranks (e.g., broadcast the eval loss/stop flag and apply the same train_iters update everywhere).
Suggested change:
```diff
-if target_eval_loss > 0 and hasattr(args, "_eval_val_loss"):
-    if args._eval_val_loss <= target_eval_loss:
-        run_duration = time.time() - train_start_time
-        log_rank_0(
-            f"[EarlyStop] Reached target! Stopping training. "
-            f"eval_loss: {args._eval_val_loss:.6f} (target: {target_eval_loss}) | "
+if target_eval_loss > 0:
+    local_eval_val_loss = getattr(args, "_eval_val_loss", None)
+    stop_training = (
+        local_eval_val_loss is not None and local_eval_val_loss <= target_eval_loss
+    )
+    synced_eval_val_loss = local_eval_val_loss
+    if dist.is_available() and dist.is_initialized():
+        sync_device = (
+            torch.device("cuda", torch.cuda.current_device())
+            if torch.cuda.is_available()
+            else torch.device("cpu")
+        )
+        stop_tensor = torch.tensor(
+            [1 if stop_training else 0], device=sync_device, dtype=torch.int32
+        )
+        eval_loss_tensor = torch.tensor(
+            [
+                local_eval_val_loss
+                if local_eval_val_loss is not None
+                else float("inf")
+            ],
+            device=sync_device,
+            dtype=torch.float32,
+        )
+        dist.all_reduce(stop_tensor, op=dist.ReduceOp.MAX)
+        dist.all_reduce(eval_loss_tensor, op=dist.ReduceOp.MIN)
+        stop_training = bool(stop_tensor.item())
+        synced_eval_val_loss = eval_loss_tensor.item()
+        if synced_eval_val_loss == float("inf"):
+            synced_eval_val_loss = None
+    if stop_training:
+        run_duration = time.time() - train_start_time
+        eval_loss_for_log = (
+            synced_eval_val_loss
+            if synced_eval_val_loss is not None
+            else float("nan")
+        )
+        log_rank_0(
+            f"[EarlyStop] Reached target! Stopping training. "
+            f"eval_loss: {eval_loss_for_log:.6f} (target: {target_eval_loss}) | "
```
```python
run_duration = time.time() - train_start_time
target_eval_loss = float(os.environ.get("TARGET_EVAL_LOSS", "0"))
final_eval_loss = getattr(args, "_eval_val_loss", None)
status = "success" if (final_eval_loss and target_eval_loss > 0 and final_eval_loss <= target_eval_loss) else "aborted"
```
status uses final_eval_loss truthiness (final_eval_loss and ...). If the final loss is 0.0 (or any falsy float), this will incorrectly report aborted even when the target is met. Use an explicit final_eval_loss is not None check (and consider handling NaN separately if needed).
| status = "success" if (final_eval_loss and target_eval_loss > 0 and final_eval_loss <= target_eval_loss) else "aborted" | |
| status = ( | |
| "success" | |
| if ( | |
| final_eval_loss is not None | |
| and target_eval_loss > 0 | |
| and final_eval_loss <= target_eval_loss | |
| ) | |
| else "aborted" | |
| ) |
| if "lm loss" in total_loss_dict: | ||
| val = total_loss_dict["lm loss"] | ||
| args._eval_val_loss = val.item() if hasattr(val, "item") else float(val) |
args._eval_val_loss is only set on ranks where is_pipeline_stage_containing_loss() is true. Any logic that relies on this value (e.g., early stopping) must broadcast/synchronize it to all ranks; otherwise different ranks will make different control-flow decisions and can hang with pipeline parallelism > 1.
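A minimal sketch of one way to do this broadcast, assuming the vendored Megatron exposes the usual `parallel_state` helpers (the function name `sync_eval_loss` is illustrative, not the PR's code):

```python
# Sketch (not the PR's code): broadcast the eval loss from the last pipeline
# stage so every rank takes the same early-stop branch.
import torch
import torch.distributed as dist
from megatron.core import parallel_state  # assumed import path


def sync_eval_loss(args):
    device = torch.device("cuda", torch.cuda.current_device())
    # Ranks that never computed the loss contribute +inf as a sentinel.
    loss = torch.tensor(
        [getattr(args, "_eval_val_loss", float("inf"))],
        device=device,
        dtype=torch.float32,
    )
    dist.broadcast(
        loss,
        src=parallel_state.get_pipeline_model_parallel_last_rank(),
        group=parallel_state.get_pipeline_model_parallel_group(),
    )
    if loss.item() != float("inf"):
        args._eval_val_loss = loss.item()
```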
```python
phase_transition_samples = (
    [0]
    + [t * args.global_batch_size for t in args.phase_transition_iterations]
    + [args.train_samples]
```
In the phase-transition branch, phase_transition_samples appends args.train_samples, but train_samples may have been computed from args.train_iters * global_batch_size when args.train_samples is unset/None. Appending None here will break comparisons/arithmetic. Use the computed train_samples variable (or args.train_iters * args.global_batch_size) instead of args.train_samples.
Suggested change:
```diff
-    + [args.train_samples]
+    + [train_samples]
```
```yaml
train_data_path: "/data/mlperf/data/c4-train.en_6_text_document"
valid_data_path: "/data/mlperf/data/c4-validation-91205-samples.en_text_document"
test_data_path: null
seq_length: 8192
```
seq_length is defined again here, duplicating the earlier seq_length setting. Please remove one of the duplicate keys to avoid ambiguity and YAML-parser incompatibilities.
Suggested change:
```diff
-seq_length: 8192
```
```bash
set -e

mkdir -p /results
```
With set -e, a non-zero torchrun exit will terminate the script immediately, so ret_code=$? and the explicit failure handling below never run. Consider disabling -e around torchrun (or using torchrun ...; ret_code=$? with set +e/set -e guards) so timing/result logging works on failures too.
```bash
torchrun \
    --nproc_per_node=${GPUS_PER_NODE} \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    src/train.py

ret_code=$?
```
This ret_code=$? check is ineffective when the script runs with set -e (the script exits immediately if torchrun fails). If you want to handle failures explicitly, capture the exit code by temporarily disabling -e or by using torchrun ... || ret_code=$? and then re-enable -e.
Suggested change:
```diff
-torchrun \
-    --nproc_per_node=${GPUS_PER_NODE} \
-    --nnodes=${NNODES} \
-    --node_rank=${NODE_RANK} \
-    --master_addr=${MASTER_ADDR} \
-    --master_port=${MASTER_PORT} \
-    src/train.py
-ret_code=$?
+ret_code=0
+torchrun \
+    --nproc_per_node=${GPUS_PER_NODE} \
+    --nnodes=${NNODES} \
+    --node_rank=${NODE_RANK} \
+    --master_addr=${MASTER_ADDR} \
+    --master_port=${MASTER_PORT} \
+    src/train.py || ret_code=$?
```
```python
TORCHPROF_VERBOSE = os.getenv("TORCHPROF_VERBOSE", 1)
TORCHPROF_DEVICES = os.getenv("TORCHPROF_DEVICES", "GPU")
TORCHPROF_MAXROWS = os.getenv("TORCHPROF_MAXROWS", 100)
```
TORCHPROF_VERBOSE and TORCHPROF_MAXROWS are read via os.getenv without casting, so they become strings. TORCHPROF_MAXROWS is later passed to table(row_limit=...), which expects an int. Please cast these with int(...) (and then to bool for verbose).
Suggested change:
```diff
-TORCHPROF_VERBOSE = os.getenv("TORCHPROF_VERBOSE", 1)
-TORCHPROF_DEVICES = os.getenv("TORCHPROF_DEVICES", "GPU")
-TORCHPROF_MAXROWS = os.getenv("TORCHPROF_MAXROWS", 100)
+TORCHPROF_VERBOSE = bool(int(os.getenv("TORCHPROF_VERBOSE", 1)))
+TORCHPROF_DEVICES = os.getenv("TORCHPROF_DEVICES", "GPU")
+TORCHPROF_MAXROWS = int(os.getenv("TORCHPROF_MAXROWS", 100))
```
```python
TORCHPROF_PROFILE_MEMORY = bool(os.getenv("TORCHPROF_PROFILE_MEMORY", 1))
TORCHPROF_WITH_STACK = bool(os.getenv("TORCHPROF_WITH_STACK", 0))
TORCHPROF_RECORD_SHAPES = bool(os.getenv("TORCHPROF_RECORD_SHAPES", 1))
TORCHPROF_WITH_FLOPS = bool(os.getenv("TORCHPROF_WITH_FLOPS", 1))
```
bool(os.getenv(...)) treats any non-empty string as True, so values like '0' will incorrectly enable options such as TORCHPROF_PROFILE_MEMORY/TORCHPROF_WITH_STACK. Parse these as integers first (e.g., bool(int(os.getenv(..., '0')))), or implement a small strtobool helper.
Suggested change:
```diff
-TORCHPROF_PROFILE_MEMORY = bool(os.getenv("TORCHPROF_PROFILE_MEMORY", 1))
-TORCHPROF_WITH_STACK = bool(os.getenv("TORCHPROF_WITH_STACK", 0))
-TORCHPROF_RECORD_SHAPES = bool(os.getenv("TORCHPROF_RECORD_SHAPES", 1))
-TORCHPROF_WITH_FLOPS = bool(os.getenv("TORCHPROF_WITH_FLOPS", 1))
+TORCHPROF_PROFILE_MEMORY = bool(int(os.getenv("TORCHPROF_PROFILE_MEMORY", "1")))
+TORCHPROF_WITH_STACK = bool(int(os.getenv("TORCHPROF_WITH_STACK", "0")))
+TORCHPROF_RECORD_SHAPES = bool(int(os.getenv("TORCHPROF_RECORD_SHAPES", "1")))
+TORCHPROF_WITH_FLOPS = bool(int(os.getenv("TORCHPROF_WITH_FLOPS", "1")))
```
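Alternatively, the small strtobool-style helper mentioned above could look like this (the name `env_flag` is illustrative, not from the PR):

```python
# Illustrative helper: parse common truthy/falsy spellings instead of
# relying on bool(str), which treats any non-empty string as True.
import os


def env_flag(name: str, default: bool = False) -> bool:
    val = os.getenv(name)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true", "yes", "on")


TORCHPROF_PROFILE_MEMORY = env_flag("TORCHPROF_PROFILE_MEMORY", True)
TORCHPROF_WITH_STACK = env_flag("TORCHPROF_WITH_STACK", False)
```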
```python
self.num_quantizers = num_quantizers
self.dtype = get_fp4_te_dtype(recipe)
if device is None:
    device = torch.device("cuda")
```
```bash
export TORCHPROF_OUTPUT_DIR=/home/vidgoyal/small_llm_pretraining/primus/outputs/
export TORCHPROF_VERBOSE=0
export TORCHPROF_MAXROWS=100
export TORCHPROF_PROFILE_MEMORY=0
export TORCHPROF_WITH_STACK=0
export TORCHPROF_RECORD_SHAPES=0
export TORCHPROF_WITH_FLOPS=0
```
TORCHPROF_OUTPUT_DIR is hardcoded to a developer home directory. This will break in containerized/CI runs and on other machines. Prefer a relative path, a /results/... default, or require callers to set the env var.
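One possible portable default, assuming the MLPerf container mounts /results (as run_and_time.sh already expects with `mkdir -p /results`):

```bash
# Sketch: honor a caller-provided value, otherwise default under /results.
export TORCHPROF_OUTPUT_DIR=${TORCHPROF_OUTPUT_DIR:-/results/torchprof}
mkdir -p "${TORCHPROF_OUTPUT_DIR}"
```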
````markdown
```

## Notes

- `log_interval: 99999999` suppresses regular Primus logs
````
The note claims log_interval: 99999999 suppresses logs, but the provided config sets log_interval: 999999. Please align the README with the actual value/config behavior to avoid confusion.
Suggested change:
```diff
-- `log_interval: 99999999` suppresses regular Primus logs
+- `log_interval: 999999` suppresses regular Primus logs
```
```python
import primus_turbo  # pylint: disable=W0611

HAVE_TURBO = True
HAVE_TURBO = False
```
In the Primus-Turbo availability probe, HAVE_TURBO is set to False even when import primus_turbo succeeds. This makes the if HAVE_TE and HAVE_TURBO: branch unreachable and effectively disables the Turbo FP4 path even when installed. Set HAVE_TURBO = True on successful import (and keep False only in the except path).
Suggested change:
```diff
-HAVE_TURBO = False
+HAVE_TURBO = True
```
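For context, the probe would then look roughly like this (the try/except structure is inferred from the comment, not copied from the PR):

```python
# Sketch of the availability probe as the comment intends it.
try:
    import primus_turbo  # pylint: disable=W0611

    HAVE_TURBO = True  # import succeeded: Turbo FP4 path is available
except ImportError:
    HAVE_TURBO = False  # False only when the import fails
```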
```bash
export PRIMUS_PATH=/home/vidgoyal/Primus-dev/Primus/
export PRIMUS_MLPERF=1
export PYTHONPATH="${PRIMUS_PATH}:${PRIMUS_PATH}/third_party/Megatron-LM:${PYTHONPATH}"
export EXP=/home/vidgoyal/Primus-dev/Primus/examples/mlperf/configs/MI355X/llama3.1_8B-pretrain-FP8.yaml
export DATA_PATH=/data
```
This config script hardcodes developer-specific absolute paths for PRIMUS_PATH and EXP, which prevents reuse on other systems. Consider making these derived from the script location (e.g., repo root) or requiring them as inputs, and keep only portable defaults.
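A sketch of the script-relative approach, assuming this config script stays at examples/mlperf/ inside the repo:

```bash
# Sketch: derive paths from the script's own location instead of hardcoding.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
export PRIMUS_PATH="${PRIMUS_PATH:-$(cd "${SCRIPT_DIR}/../.." && pwd)}"
export EXP="${EXP:-${SCRIPT_DIR}/configs/MI355X/llama3.1_8B-pretrain-FP8.yaml}"
export PYTHONPATH="${PRIMUS_PATH}:${PRIMUS_PATH}/third_party/Megatron-LM:${PYTHONPATH}"
```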
```yaml
mock_data: false
train_data_path: "/data/mlperf/data/c4-train.en_6_text_document"
valid_data_path: "/data/mlperf/data/c4-validation-91205-samples.en_text_document"
test_data_path: null
```
seq_length is defined twice under overrides (once near the top and again in the data section). YAML will keep only the latter, which is easy to miss and can cause confusing config drift. Remove the duplicate key (or add a comment explaining intentional override).
Suggested change:
```diff
-test_data_path: null
+test_data_path: null
+# Intentionally overrides an earlier `seq_length` in `overrides`; 8192 is the effective value.
```
```python
run_duration = time.time() - train_start_time
target_eval_loss = float(os.environ.get("TARGET_EVAL_LOSS", "0"))
final_eval_loss = getattr(args, "_eval_val_loss", None)
status = "success" if (final_eval_loss and target_eval_loss > 0 and final_eval_loss <= target_eval_loss) else "aborted"
```
status uses a truthiness check on final_eval_loss (final_eval_loss and ...). This misclassifies valid losses like 0.0 as false and will also report aborted on ranks where _eval_val_loss is not set. Prefer final_eval_loss is not None (and, if needed, gather/broadcast the final eval loss to rank 0 before logging).
| status = "success" if (final_eval_loss and target_eval_loss > 0 and final_eval_loss <= target_eval_loss) else "aborted" | |
| status = ( | |
| "success" | |
| if ( | |
| final_eval_loss is not None | |
| and target_eval_loss > 0 | |
| and final_eval_loss <= target_eval_loss | |
| ) | |
| else "aborted" | |
| ) |
```bash
set -e

mkdir -p /results

export GPUS_PER_NODE=${GPUS_PER_NODE:-8}
export NNODES=${NNODES:-1}
export NODE_RANK=${NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-localhost}
export MASTER_PORT=${MASTER_PORT:-29502}
export EXP=${EXP:-/workspace/code/conf/llama3.1_8B-pretrain.yaml}
export DATA_PATH=${DATA_PATH:-/data}

echo "============================================"
echo "MLPerf LLama3.1 8B Training"
echo "============================================"
echo "Config: ${EXP}"
echo "Data: ${DATA_PATH}"
echo "GPUs: ${GPUS_PER_NODE}"
echo "Nodes: ${NNODES}"
echo "Train iters: ${PRIMUS_TRAIN_ITERS}"
echo "Eval interval: ${PRIMUS_EVAL_INTERVAL}"
echo "Enable MLPerf logging: ${ENABLE_MLPERF}"
echo "MLLOG_TRAIN_LOSS_LOG_FREQ: ${MLLOG_TRAIN_LOSS_LOG_FREQ}"
echo "MLLOG_TARGET_EVAL_LOSS: ${MLLOG_TARGET_EVAL_LOSS}"
echo "MLLOG_SUBMISSION_BENCHMARK: ${MLLOG_SUBMISSION_BENCHMARK}"
echo "MLLOG_SUBMISSION_DIVISION: ${MLLOG_SUBMISSION_DIVISION}"
echo "MLLOG_SUBMISSION_ORG: ${MLLOG_SUBMISSION_ORG}"
echo "MLLOG_SUBMISSION_PLATFORM: ${MLLOG_SUBMISSION_PLATFORM}"
echo "============================================"

start=$(date +%s)
start_fmt=$(date +%Y-%m-%d\ %r)
echo "STARTING TIMING RUN AT $start_fmt"

torchrun \
    --nproc_per_node=${GPUS_PER_NODE} \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    src/train.py

ret_code=$?
```
This script uses set -e but then tries to capture ret_code=$? after torchrun. With -e, a non-zero torchrun exit will abort the script immediately, so ret_code/timing output won’t be recorded. If you want timing even on failure, temporarily disable -e around torchrun (or use an if ...; then ...; fi pattern) and handle the exit code explicitly.
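The `if` pattern mentioned above would look roughly like this (a sketch, not the PR's code):

```bash
# Sketch: an `if` condition suppresses set -e for the tested command, so the
# exit code can be handled explicitly and timing still runs on failure.
if torchrun \
    --nproc_per_node=${GPUS_PER_NODE} \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    src/train.py; then
    ret_code=0
else
    ret_code=$?
fi
```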
```python
PRIMUS_PATH = os.getenv("PRIMUS_PATH", "/home/vidgoyal/Primus")
MEGATRON_PATH = os.path.join(PRIMUS_PATH, "third_party/Megatron-LM")

if PRIMUS_PATH not in sys.path:
    sys.path.insert(0, PRIMUS_PATH)
if MEGATRON_PATH not in sys.path:
    sys.path.insert(0, MEGATRON_PATH)
```
PRIMUS_PATH defaults to a user-specific absolute path (/home/vidgoyal/Primus) and the script mutates sys.path at import time. This makes the entrypoint non-portable. Prefer requiring PRIMUS_PATH to be set (or derive it relative to this file/repo root) and avoid hardcoding developer home directories.
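A sketch of the file-relative alternative, assuming train.py stays at examples/mlperf/src/train.py inside the repo:

```python
# Sketch: derive the repo root from this file's location; fall back to it
# only when PRIMUS_PATH is unset, never to a developer home directory.
import os
import pathlib
import sys

REPO_ROOT = pathlib.Path(__file__).resolve().parents[3]  # examples/mlperf/src -> repo root
PRIMUS_PATH = os.getenv("PRIMUS_PATH", str(REPO_ROOT))
MEGATRON_PATH = os.path.join(PRIMUS_PATH, "third_party/Megatron-LM")
for p in (PRIMUS_PATH, MEGATRON_PATH):
    if p not in sys.path:
        sys.path.insert(0, p)
```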
This PR adds MLPerf support for Llama 3.1 8B FP8.
Time to train: 98.0 mins; eval_loss: 3.292395 (target: 3.3)
Full log:
mlperf_primus_llama3.18b.log