215 changes: 215 additions & 0 deletions docs/user-guide/features/cuda_graph.md
@@ -0,0 +1,215 @@
<!---
Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software or related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-->

# CUDA Graph

CUDA Graphs reduce kernel-launch overhead by recording GPU operations once and replaying the recording on subsequent iterations. Megatron-LM provides three CUDA graph implementations controlled by `--cuda-graph-impl`.

For implementation background and design details, see NVIDIA's
[Transformer Engine and Megatron-LM CUDA Graph Support](https://docs.nvidia.com/dl-cuda-graph/torch-cuda-graph/te-megatron-cuda-graphs.html).
That article is a useful conceptual reference, but some examples there still use older flags such as
`--enable-cuda-graph` or `--cuda-graph-scope full_iteration`; in this repository, prefer
`--cuda-graph-impl local|transformer_engine|full_iteration` as documented below.
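
To make the record-and-replay idea concrete, here is a minimal plain-PyTorch sketch, independent
of any Megatron-LM flag; tensor sizes and the warmup count are placeholders:

```python
import torch

# A captured graph always reads and writes the same memory addresses,
# so inputs live in fixed "static" tensors that we copy into before each replay.
model = torch.nn.Linear(1024, 1024).cuda()
static_x = torch.randn(64, 1024, device="cuda")

# Warm up on a side stream so lazy initialization and autotuning finish before capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(side)

# Record the GPU work once ...
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = model(static_x)

# ... then replay the whole recorded kernel sequence with a single launch per step.
for _ in range(10):
    static_x.copy_(torch.randn(64, 1024, device="cuda"))
    graph.replay()
    # static_y now holds this step's output.
```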

## Overview

CUDA graph behavior is set by three orthogonal flags:

| Flag | Values | Purpose |
|---|---|---|
| `--cuda-graph-impl` | `none` / `local` / `transformer_engine` / `full_iteration` | Which capture backend or strategy to use |
| `--cuda-graph-modules` | `attn` / `mlp` / `moe` / `moe_router` / `moe_preprocess` / `mamba` | Per-layer **training** capture coverage; multi-valued and only meaningful for `local` and `transformer_engine` |
| `--inference-cuda-graph-scope` | `none` / `layer` / `block` | Granularity of CUDA graphs during **inference**; only `local` supports non-`none` values |

Supported combinations:

| `--cuda-graph-impl` | Backend | Training capture | Inference capture |
|---|---|---|---|
| `none` | — | off | off |
| `local` | MCore `CudaGraphManager` | per-layer, controlled by `--cuda-graph-modules` | `layer` (default) or `block`, controlled by `--inference-cuda-graph-scope` |
| `transformer_engine` | TE `make_graphed_callables()` | per-layer, controlled by `--cuda-graph-modules` | not supported (`none` only) |
| `full_iteration` | MCore `FullCudaGraphWrapper` | one graph per training iteration; `--cuda-graph-modules` must be empty | not supported (`none` only) |

---

## CUDA Graph — Local Implementation (`--cuda-graph-impl local`)

Uses MCore's built-in `CudaGraphManager`. During training, this is a per-layer mode:
leaving `--cuda-graph-modules` unset captures the whole Transformer layer, while specifying
modules restricts capture to selected sub-regions. During inference, `local` can instead attach
graphs at either the layer boundary or the enclosing block boundary, as controlled by
`--inference-cuda-graph-scope`.

Operationally, this path is tightly integrated into MCore training and inference:

- graphable modules create and own their `CudaGraphManager` instances automatically
- the existing training schedules drive warmup/capture/replay automatically
- users select the mode through config flags only; there is no separate helper API to
  wire into a custom training loop and no manual handling of static input buffers (a
  conceptual sketch of what this automation does per layer follows below)
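
The sketch below is only a conceptual picture of that built-in integration for a single layer.
The class and method names are hypothetical, not the real `CudaGraphManager` API, and training
capture additionally has to manage the backward pass and gradient buffers:

```python
import torch

class GraphedLayerSketch(torch.nn.Module):
    """Hypothetical illustration of a layer whose forward is graph-managed."""

    def __init__(self, layer: torch.nn.Module, warmup_steps: int = 3):
        super().__init__()
        self.layer = layer
        self.warmup_steps = warmup_steps
        self.steps_seen = 0
        self.graph = None
        self.static_in = None
        self.static_out = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.steps_seen < self.warmup_steps:
            # Eager warmup iterations (lazy initialization, autotuning).
            self.steps_seen += 1
            return self.layer(x)
        if self.graph is None:
            # One-time capture into static input/output buffers.
            self.static_in = x.clone()
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_out = self.layer(self.static_in)
        # Replay: copy the new activation into the fixed buffer, relaunch the recording.
        self.static_in.copy_(x)
        self.graph.replay()
        return self.static_out
```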

### Usage

```bash
--cuda-graph-impl local
```

### `--cuda-graph-modules` options

| Module | What is captured |
|---|---|
| *(empty / not set)* | Entire Transformer layer (default) |
| `attn` | `TransformerLayer._forward_attention()` |
| `mlp` | `TransformerLayer._forward_mlp()` for dense layers |
| `moe` | `TransformerLayer._forward_mlp()` for MoE layers (drop-and-pad only) |
| `moe_router` | MoE router + shared experts (if not EP-comm-overlapped) |
| `moe_preprocess` | `MoELayer.preprocess()` — must be paired with `moe_router` |
| `mamba` | Mamba SSM layer |

**Example — MoE model, capture attention and router:**
```bash
# Restrict the captured modules (the whole-layer default does not work with MoE dynamic shapes)
--cuda-graph-impl local \
--cuda-graph-modules attn moe_router moe_preprocess
```

---

## CUDA Graph — Transformer Engine Implementation (`--cuda-graph-impl transformer_engine`)

Uses Transformer Engine's `make_graphed_callables()` path. In Megatron-LM's CLI, this has the
same training granularity as `local`: leaving `--cuda-graph-modules` unset captures the whole
Transformer layer, while specifying modules restricts capture to selected sub-regions. The main difference from
`local` is the backend implementation and feature compatibility. Unlike `local`, this path does
not support inference CUDA graphs.

Compared to `local`, this path exposes a more general and self-contained API via TE's
`make_graphed_callables()`, giving users greater flexibility and control over how CUDA graphs are
wired into custom training loops. The trade-off is that it requires more manual setup:

- the training loop must instantiate `TECudaGraphHelper`
- the training loop must call helper methods such as `create_cudagraphs()` and
`cuda_graph_set_manual_hooks()` at the correct points
Comment on lines +91 to +97

**Contributor:** I'd also note that this path does not do the static input buffer checking that the local path does to ensure that inputs get fixed memory addresses out of the box.

**Contributor Author:** I think `TECudaGraphHelper` makes this checking too?


Megatron-LM's stock training loop already wires these calls in `megatron/training/training.py`,
but custom training scripts must do the same work themselves.
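
As a rough orientation for custom scripts, the sketch below shows where such calls might sit in a
bare-bones loop. Only the helper class and the two method names come from this repository; the
import path, constructor arguments, and exact call sites are assumptions, so treat
`megatron/training/training.py` as the authoritative reference:

```python
# Hedged sketch of wiring the TE-backed CUDA graph helper into a custom training loop.
# The import path and constructor signature below are assumptions, not the verified API.
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper

def train(model, optimizer, config, data_iterator, num_iterations, loss_fn):
    # 1. Build the helper once, before the main loop starts.
    helper = TECudaGraphHelper(model=model, config=config)

    # 2. Capture the per-layer graphs via TE's make_graphed_callables().
    helper.create_cudagraphs()

    # 3. Register any hooks that the graphed callables can no longer trigger on their own.
    helper.cuda_graph_set_manual_hooks()

    # 4. Run the usual loop; graphed modules replay instead of re-launching kernels.
    for _ in range(num_iterations):
        batch, target = next(data_iterator)
        loss = loss_fn(model(batch), target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```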

### Usage

```bash
--cuda-graph-impl transformer_engine \
--cuda-graph-modules attn moe_router moe_preprocess
```

The same training `--cuda-graph-modules` options apply as for `local`, and the default is likewise
whole-layer training capture when the flag is omitted.

---

## Full-Iteration Training CUDA Graph (`--cuda-graph-impl full_iteration`)

Captures the entire training iteration (excluding optimizer) as a single CUDA graph. The same
wrapper is also used for training-loop validation/eval in forward-only mode. This provides the
largest training/validation latency reduction.

This implementation does not create inference CUDA graphs. For inference, use
`--cuda-graph-impl local --inference-cuda-graph-scope layer|block`.
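
Conceptually this is the same record-and-replay pattern as the per-layer modes, applied to the
whole forward-backward step, with the optimizer update left outside the graph. The standalone
PyTorch sketch below illustrates the idea only; it is not the `FullCudaGraphWrapper`
implementation:

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512)
).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
static_x = torch.randn(32, 512, device="cuda")
static_y = torch.randn(32, 512, device="cuda")

# Warm up on a side stream; keep .grad buffers allocated (set_to_none=False)
# so backward always accumulates into the same fixed addresses.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=False)
        torch.nn.functional.mse_loss(model(static_x), static_y).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(side)

# Capture forward + backward of one iteration as a single graph.
graph = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=False)
with torch.cuda.graph(graph):
    static_loss = torch.nn.functional.mse_loss(model(static_x), static_y)
    static_loss.backward()

# Replay once per step; the optimizer update runs eagerly outside the graph.
for _ in range(10):
    static_x.copy_(torch.randn(32, 512, device="cuda"))
    static_y.copy_(torch.randn(32, 512, device="cuda"))
    optimizer.zero_grad(set_to_none=False)  # zero in place, keeping gradient addresses fixed
    graph.replay()                          # writes this step's gradients into .grad
    optimizer.step()
```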

### Requirements

- `--no-check-for-nan-in-loss-and-grad` is required: NaN checks involve CPU-GPU synchronization
which cannot run inside a CUDA graph.
- `--cuda-graph-modules` must be omitted (or left empty): per-module selection has no meaning
when the entire iteration is captured as a single graph.

### Example

```bash
--cuda-graph-impl full_iteration \
--no-check-for-nan-in-loss-and-grad
```

---

## Common Configuration Examples

### Dense Model Training

All three implementations work for dense models:

```bash
# Per-layer (local)
--cuda-graph-impl local
# equivalent: --cuda-graph-impl local --cuda-graph-modules attn mlp

# Per-layer (TE)
--cuda-graph-impl transformer_engine
# equivalent: --cuda-graph-impl transformer_engine --cuda-graph-modules attn mlp

# Full-iteration
--cuda-graph-impl full_iteration \
--no-check-for-nan-in-loss-and-grad
```

### MoE Model Training

MoE expert dispatch involves dynamic shapes and cannot be captured. `--cuda-graph-modules` is used
to capture only the static parts (attention, router, preprocess) while leaving expert compute in
eager mode. Example using `transformer_engine` (`local` works the same way):

```bash
--cuda-graph-impl transformer_engine \
--cuda-graph-modules attn moe_router moe_preprocess
```
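
To make the dynamic-shape constraint concrete: with top-k routing, the number of tokens routed to
each expert changes from step to step, so the buffers built during expert dispatch have
data-dependent shapes, which a fixed recording cannot replay. A tiny illustration with made-up
sizes:

```python
import torch

num_tokens, num_experts, top_k = 8, 4, 2
for step in range(3):
    router_logits = torch.randn(num_tokens, num_experts)
    _, expert_idx = router_logits.topk(top_k, dim=-1)  # (num_tokens, top_k)
    tokens_per_expert = torch.bincount(expert_idx.flatten(), minlength=num_experts)
    # The per-expert counts (and hence the dispatch buffer shapes) differ every step,
    # e.g. [5, 3, 4, 4] one step and [2, 6, 5, 3] the next.
    print(step, tokens_per_expert.tolist())
```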

With paged stash (currently available only on `dev`; see
`docs/user-guide/features/paged_stash.md` on the `dev` branch), expert dispatch shapes become
static (pre-sized via `--moe-expert-rank-capacity-factor`), which allows full-iteration CUDA
graphs to be used on MoE models as well:

```bash
--cuda-graph-impl full_iteration \
--no-check-for-nan-in-loss-and-grad \
--moe-flex-dispatcher-backend hybridep \
--use-transformer-engine-op-fuser \
--moe-expert-rank-capacity-factor <float> \
--moe-paged-stash
```

---

## Additional Notes

- `--cuda-graph-warmup-steps` (default: 3) controls how many warmup steps run before CUDA graph
capture. Setting it to 0 is not recommended: some operations rely on the first few iterations
for lazy initialization or autotuning, and capturing too early may produce incorrect or
suboptimal graphs.
- Inference CUDA graphs (serving or RL rollout) currently require
`--cuda-graph-impl local`. Use `--inference-cuda-graph-scope layer|block` with
`local`; all other implementations must set `--inference-cuda-graph-scope none`,
meaning inference runs in eager mode.
- Background reference: [Transformer Engine and Megatron-LM CUDA Graph Support](https://docs.nvidia.com/dl-cuda-graph/torch-cuda-graph/te-megatron-cuda-graphs.html),
which also covers PyTorch CUDA Graph best practices and lessons learned.

---

## Migration Guide

Legacy configurations are still accepted and automatically migrated at runtime. This includes
`--enable-cuda-graph`, `--external-cuda-graph`, the old `--cuda-graph-scope` flag (now
`--cuda-graph-modules`), and the deprecated module values `full_iteration` and
`full_iteration_inference`. We nevertheless encourage updating your configs to the new forms:

| Old command | New command |
|---|---|
| `--enable-cuda-graph` | `--cuda-graph-impl local` |
| `--external-cuda-graph` | `--cuda-graph-impl transformer_engine` |
| `--cuda-graph-scope <modules>` | `--cuda-graph-modules <modules>` |
| `--cuda-graph-impl local --cuda-graph-scope full_iteration` | `--cuda-graph-impl full_iteration` |
| `--cuda-graph-impl local --cuda-graph-scope full_iteration_inference` | `--cuda-graph-impl local --inference-cuda-graph-scope block` |
| `--cuda-graph-impl local --cuda-graph-scope attn moe_router moe_preprocess full_iteration_inference` | `--cuda-graph-impl local --cuda-graph-modules attn moe_router moe_preprocess --inference-cuda-graph-scope block` |
1 change: 1 addition & 0 deletions docs/user-guide/features/index.md
@@ -14,6 +14,7 @@ Guides for Megatron Core training features.
```{toctree}
:maxdepth: 2
cuda_graph
fine_grained_activation_offloading
moe
context_parallel
1 change: 0 additions & 1 deletion examples/rl/model_configs/nemotron6_3b_moe.sh
@@ -65,7 +65,6 @@ MODEL_OPTIONS="\
--inference-dynamic-batching-num-cuda-graphs 2 \
--decode-only-cuda-graphs \
--cuda-graph-impl local \
--cuda-graph-scope full \
--use-checkpoint-args \
--enable-experimental \
--cross-entropy-loss-fusion \
17 changes: 6 additions & 11 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -45,7 +45,7 @@
from megatron.core.inference.utils import Counter, await_process_call
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.enums import InferenceCudaGraphScope
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
from megatron.core.utils import (
deprecate_args,
Expand Down Expand Up @@ -224,7 +224,8 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen
self.disable_ep_consensus = inference_config.disable_ep_consensus
self.ep_consensus_interval = inference_config.ep_consensus_interval
self.cuda_graph_impl = model_config.cuda_graph_impl
self.cuda_graph_scope = model_config.cuda_graph_scope
self.inference_cuda_graph_scope = model_config.inference_cuda_graph_scope
self.cuda_graph_modules = model_config.cuda_graph_modules
# Initialize engine.
self.reset()

@@ -331,17 +332,11 @@ def create_cuda_graphs(self, reset_context: bool = True):
reset_context (bool): Whether to reset the context after building cuda graphs.
"""

if self.cuda_graph_impl != "local":
if self.inference_cuda_graph_scope == InferenceCudaGraphScope.none:
return

if (
CudaGraphScope.full_iteration in self.cuda_graph_scope
and CudaGraphScope.full_iteration_inference not in self.cuda_graph_scope
):
warnings.warn(
"\n\n*** WARNING: 'full_iteration' CUDA graph scope used during inference! "
"This will not create inference CUDA graphs. Use '--cuda-graph-scope=full_iteration_inference' instead. ***\n"
)
if self.cuda_graph_impl != "local":
return

context = self.context
controller = self.controller
Original file line number Diff line number Diff line change
@@ -36,7 +36,6 @@
gather_from_sequence_parallel_region,
scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.moe_layer import BaseMoELayer
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
from megatron.core.transformer.utils import set_model_to_sequence_parallel
@@ -1924,10 +1923,7 @@ def generate_all_output_tokens_static_batch(
)

# Check whether CUDA graphs are enabled
enable_cuda_graph = (
model_config.cuda_graph_impl == "local"
and CudaGraphScope.full_iteration not in model_config.cuda_graph_scope
)
enable_cuda_graph = model_config.cuda_graph_impl == "local"

# Pad batch tokens if necessary
batch_size = len(active_requests)
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
is_vp_last_stage,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import AttnBackend, CudaGraphScope
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.multi_token_prediction import tie_word_embeddings_state_dict
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -159,8 +159,8 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
labels = torch.as_strided(labels, labels.size(), (labels.size()[1], 1))
# Use is_cg_capturable=True for full iteration CUDA graphs to avoid torch.equal checks
is_cg_capturable = (
hasattr(self.config, 'cuda_graph_scope')
and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
hasattr(self.config, 'cuda_graph_impl')
and self.config.cuda_graph_impl == "full_iteration"
)
if is_cg_capturable and not is_te_min_version("2.7.0"):
from megatron.core.utils import get_te_version
@@ -169,7 +169,7 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
raise AssertionError(
f"CUDA graph compatible cross entropy requires TransformerEngine >= 2.7.0, "
f"but found version {current_version}. Please upgrade TransformerEngine "
f"or set cuda_graph_scope to a value other than 'full_iteration'."
f"or set cuda_graph_impl to a value other than 'full_iteration'."
)

loss = te_parallel_cross_entropy(
10 changes: 5 additions & 5 deletions megatron/core/models/gpt/fine_grained_callables.py
@@ -14,7 +14,7 @@
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.enums import CudaGraphModule
from megatron.core.transformer.module import GraphableMegatronModule, float16_to_fp32
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.multi_token_prediction import (
@@ -93,7 +93,7 @@ def should_free_input(name, is_moe, config, num_local_experts):
# If moe_preprocess is in cuda graph scope, tokens and probs are fixed size tensors,
# so they cannot be freed.
"moe_dispatch": not (enable_deepep or enable_hybridep)
and (CudaGraphScope.moe_preprocess not in config.cuda_graph_scope),
and (CudaGraphModule.moe_preprocess not in config.cuda_graph_modules),
}

return free_input_nodes.get(name, False)
@@ -388,16 +388,16 @@ def __init__(self, layer):
)
else:
self.shared_expert_dw_callable = None
self.cuda_graph_scope = layer.config.cuda_graph_scope
self.cuda_graph_modules = layer.config.cuda_graph_modules

def backward_dw(self):
"""Execute weight gradients, skipping CUDA graphed components during replay."""
is_replay = hasattr(self.layer, 'cuda_graphs') and self.layer.cuda_graphs
if self.shared_expert_dw_callable is not None and (
not is_replay or CudaGraphScope.moe_router not in self.cuda_graph_scope
not is_replay or CudaGraphModule.moe_router not in self.cuda_graph_modules
):
self.shared_expert_dw_callable()
if not is_replay or CudaGraphScope.attn not in self.cuda_graph_scope:
if not is_replay or CudaGraphModule.attn not in self.cuda_graph_modules:
self.attn_dw_callable()
if is_replay and self.graphed_backward_dw_callable is not None:
self.graphed_backward_dw_callable()
10 changes: 2 additions & 8 deletions megatron/core/models/gpt/gpt_model.py
@@ -24,7 +24,7 @@
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.quantization.utils import get_quant_config_or_none
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
mtp_on_this_rank,
@@ -408,13 +408,7 @@ def _preprocess(

if (
in_inference_mode
and (
(
self.config.cuda_graph_impl == "local"
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
)
or self.config.flash_decode
)
and (self.config.cuda_graph_impl == "local" or self.config.flash_decode)
and inference_context.is_static_batching()
):
current_batch_size = input_ids.shape[0]