diff --git a/docs/user-guide/features/cuda_graph.md b/docs/user-guide/features/cuda_graph.md
new file mode 100644
index 00000000000..28a1a5575dc
--- /dev/null
+++ b/docs/user-guide/features/cuda_graph.md
@@ -0,0 +1,215 @@

# CUDA Graph

CUDA Graphs reduce kernel-launch overhead by recording GPU operations once and replaying them on
subsequent iterations. Megatron-LM provides three CUDA graph implementations, controlled by
`--cuda-graph-impl`.

For implementation background and design details, see NVIDIA's
[Transformer Engine and Megatron-LM CUDA Graph Support](https://docs.nvidia.com/dl-cuda-graph/torch-cuda-graph/te-megatron-cuda-graphs.html).
That article is a useful conceptual reference, but some examples there still use older flags such as
`--enable-cuda-graph` or `--cuda-graph-scope full_iteration`; in this repository, prefer
`--cuda-graph-impl local|transformer_engine|full_iteration` as documented below.

## Overview

CUDA graph behavior is set by three orthogonal flags:

| Flag | Values | Purpose |
|---|---|---|
| `--cuda-graph-impl` | `none` / `local` / `transformer_engine` / `full_iteration` | Which capture backend or strategy to use |
| `--cuda-graph-modules` | `attn` / `mlp` / `moe` / `moe_router` / `moe_preprocess` / `mamba` | Per-layer **training** capture coverage; multi-valued and only meaningful for `local` and `transformer_engine` |
| `--inference-cuda-graph-scope` | `none` / `layer` / `block` | Granularity of CUDA graphs during **inference**; only `local` supports non-`none` values |

Supported combinations:

| `--cuda-graph-impl` | Backend | Training capture | Inference capture |
|---|---|---|---|
| `none` | — | off | off |
| `local` | MCore `CudaGraphManager` | per-layer, controlled by `--cuda-graph-modules` | `layer` (default) or `block`, controlled by `--inference-cuda-graph-scope` |
| `transformer_engine` | TE `make_graphed_callables()` | per-layer, controlled by `--cuda-graph-modules` | not supported (`none` only) |
| `full_iteration` | MCore `FullCudaGraphWrapper` | one graph per training iteration; `--cuda-graph-modules` must be empty | not supported (`none` only) |

---

## CUDA Graph — Local Implementation (`--cuda-graph-impl local`)

Uses MCore's built-in `CudaGraphManager`. During training, this is a per-layer mode:
leaving `--cuda-graph-modules` unset captures the whole Transformer layer, while specifying
modules restricts capture to selected sub-regions. During inference, `local` can instead attach
graphs at either the layer boundary or the enclosing block boundary, as controlled by
`--inference-cuda-graph-scope`.

Operationally, this path is tightly integrated into MCore training and inference:

- graphable modules create and own their `CudaGraphManager` instances automatically
- the existing training schedules drive warmup/capture/replay automatically
- users select the mode through config flags only; there is no separate helper API to
  wire into a custom training loop, and no need to manage static input buffers by hand

### Usage

```bash
--cuda-graph-impl local
```

### `--cuda-graph-modules` options

| Module | What is captured |
|---|---|
| *(empty / not set)* | Entire Transformer layer (default) |
| `attn` | `TransformerLayer._forward_attention()` |
| `mlp` | `TransformerLayer._forward_mlp()` for dense layers |
| `moe` | `TransformerLayer._forward_mlp()` for MoE layers (drop-and-pad only) |
| `moe_router` | MoE router + shared experts (if not EP-comm-overlapped) |
| `moe_preprocess` | `MoELayer.preprocess()` — must be paired with `moe_router` |
| `mamba` | Mamba SSM layer |

**Example — MoE model, capture attention and router:**
```bash
# Restrict capture to the static modules; the whole-layer default does not
# work with the dynamic shapes of MoE expert dispatch.
--cuda-graph-impl local \
--cuda-graph-modules attn moe_router moe_preprocess
```

---

## CUDA Graph — Transformer Engine Implementation (`--cuda-graph-impl transformer_engine`)

Uses Transformer Engine's `make_graphed_callables()` path. In Megatron-LM's CLI, this has the
same training granularity as `local`: leaving `--cuda-graph-modules` unset captures the whole
Transformer layer, while specifying modules restricts capture to selected sub-regions. The main
differences from `local` are the capture backend and its feature compatibility. Unlike `local`,
this path does not support inference CUDA graphs.

Compared to `local`, this path exposes a more general and self-contained API via TE's
`make_graphed_callables()`, giving users greater flexibility and control over how CUDA graphs are
wired into custom training loops. The trade-off is that it requires more manual setup:

- the training loop must instantiate `TECudaGraphHelper`
- the training loop must call helper methods such as `create_cudagraphs()` and
  `cuda_graph_set_manual_hooks()` at the correct points

Megatron-LM's stock training loop already wires these calls in `megatron/training/training.py`,
but custom training scripts must do the same work themselves.

### Usage

```bash
--cuda-graph-impl transformer_engine \
--cuda-graph-modules attn moe_router moe_preprocess
```

The same training `--cuda-graph-modules` options apply as for `local`, and the default is likewise
whole-layer training capture when the flag is omitted.

---

## Full-Iteration Training CUDA Graph (`--cuda-graph-impl full_iteration`)

Captures the entire training iteration (excluding the optimizer step) as a single CUDA graph. The
same wrapper is also used for training-loop validation/eval in forward-only mode. This provides
the largest training/validation latency reduction of the three implementations.

This implementation does not create inference CUDA graphs. For inference, use
`--cuda-graph-impl local --inference-cuda-graph-scope layer|block`.

### Requirements

- `--no-check-for-nan-in-loss-and-grad` is required: NaN checks involve CPU-GPU synchronization,
  which cannot run inside a CUDA graph.
- `--cuda-graph-modules` must be omitted (or left empty): per-module selection has no meaning
  when the entire iteration is captured as a single graph.
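The first requirement exists because graph capture forbids host synchronization. A minimal
plain-PyTorch sketch of the failure mode (standard `torch.cuda` APIs, not Megatron-LM code):

```python
import torch

static_x = torch.ones(4, device="cuda")
graph = torch.cuda.CUDAGraph()

# Warm up on a side stream before capture, as the PyTorch CUDA graph
# documentation recommends.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    loss = static_x.sum()
torch.cuda.current_stream().wait_stream(side)

with torch.cuda.graph(graph):
    loss = static_x.sum()
    # A NaN check would need the loss value on the host:
    #   if torch.isnan(loss).item(): ...
    # but .item() forces a CPU-GPU sync, which raises a RuntimeError while
    # the stream is capturing. Hence --no-check-for-nan-in-loss-and-grad.

graph.replay()  # re-runs the captured kernels on the static buffers
```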
+ +### Example + +```bash +--cuda-graph-impl full_iteration \ +--no-check-for-nan-in-loss-and-grad +``` + +--- + +## Common Configuration Examples + +### Dense Model Training + +All three implementations work for dense models: + +```bash +# Per-layer (local) +--cuda-graph-impl local +# equivalent: --cuda-graph-impl local --cuda-graph-modules attn mlp + +# Per-layer (TE) +--cuda-graph-impl transformer_engine +# equivalent: --cuda-graph-impl transformer_engine --cuda-graph-modules attn mlp + +# Full-iteration +--cuda-graph-impl full_iteration \ +--no-check-for-nan-in-loss-and-grad +``` + +### MoE Model Training + +MoE expert dispatch involves dynamic shapes and cannot be captured. `--cuda-graph-modules` is used +to capture only the static parts (attention, router, preprocess) while leaving expert compute in +eager mode. Example using `transformer_engine` (`local` works the same way): + +```bash +--cuda-graph-impl transformer_engine \ +--cuda-graph-modules attn moe_router moe_preprocess +``` + +With paged stash (currently available only on `dev`; see +`docs/user-guide/features/paged_stash.md` on the `dev` branch), expert dispatch shapes become +static (pre-sized via `--moe-expert-rank-capacity-factor`), which allows full-iteration CUDA +graphs to be used on MoE models as well: + +```bash +--cuda-graph-impl full_iteration \ +--no-check-for-nan-in-loss-and-grad \ +--moe-flex-dispatcher-backend hybridep \ +--use-transformer-engine-op-fuser \ +--moe-expert-rank-capacity-factor \ +--moe-paged-stash +``` + +--- + +## Additional Notes + +- `--cuda-graph-warmup-steps` (default: 3) controls how many warmup steps run before CUDA graph + capture. Setting it to 0 is not recommended: some operations rely on the first few iterations + for lazy initialization or autotuning, and capturing too early may produce incorrect or + suboptimal graphs. +- Inference CUDA graphs (serving or RL rollout) currently require + `--cuda-graph-impl local`. Use `--inference-cuda-graph-scope layer|block` with + `local`; all other implementations must set `--inference-cuda-graph-scope none`, + meaning inference runs in eager mode. +- Background reference: [Transformer Engine and Megatron-LM CUDA Graph Support](https://docs.nvidia.com/dl-cuda-graph/torch-cuda-graph/te-megatron-cuda-graphs.html), + which also covers PyTorch CUDA Graph best practices and lessons learned. 
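One related note for the `transformer_engine` implementation described earlier: custom training
scripts that bypass Megatron-LM's stock loop must wire the helper themselves. A minimal sketch,
with the import path and constructor arguments as assumptions (only the two method names are
documented above; mirror `megatron/training/training.py` for the authoritative sequence):

```python
# Sketch of TECudaGraphHelper wiring for a custom training loop. The import
# path and constructor arguments are assumptions; only create_cudagraphs()
# and cuda_graph_set_manual_hooks() are documented above.
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper


def setup_te_cuda_graphs(model, config, seq_length, micro_batch_size):
    helper = TECudaGraphHelper(model, config, seq_length, micro_batch_size)
    # Capture the per-layer graphs once, after model and optimizer setup.
    helper.create_cudagraphs()
    # Install the manual parameter hooks that graphed modules need in training.
    helper.cuda_graph_set_manual_hooks()
    return helper
```

Call this once before the first training iteration; after that, the graphed layers replay inside
the normal forward-backward schedule.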
+ +--- + +## Migration Guide + +Legacy configurations (including `--enable-cuda-graph`, `--external-cuda-graph`, the renamed +`--cuda-graph-scope` flag (now `--cuda-graph-modules`), and deprecated module values such as +`full_iteration` and `full_iteration_inference`) are still accepted and automatically migrated +at runtime, but we encourage updating your configs to the new forms: + +| Old command | New command | +|---|---| +| `--enable-cuda-graph` | `--cuda-graph-impl local` | +| `--external-cuda-graph` | `--cuda-graph-impl transformer_engine` | +| `--cuda-graph-scope ` | `--cuda-graph-modules ` | +| `--cuda-graph-impl local --cuda-graph-scope full_iteration` | `--cuda-graph-impl full_iteration` | +| `--cuda-graph-impl local --cuda-graph-scope full_iteration_inference` | `--cuda-graph-impl local --inference-cuda-graph-scope block` | +| `--cuda-graph-impl local --cuda-graph-scope attn moe_router moe_preprocess full_iteration_inference` | `--cuda-graph-impl local --cuda-graph-modules attn moe_router moe_preprocess --inference-cuda-graph-scope block` | diff --git a/docs/user-guide/features/index.md b/docs/user-guide/features/index.md index 9dea6bd34a4..ed4a7b07543 100644 --- a/docs/user-guide/features/index.md +++ b/docs/user-guide/features/index.md @@ -14,6 +14,7 @@ Guides for Megatron Core training features. ```{toctree} :maxdepth: 2 +cuda_graph fine_grained_activation_offloading moe context_parallel diff --git a/examples/rl/model_configs/nemotron6_3b_moe.sh b/examples/rl/model_configs/nemotron6_3b_moe.sh index a807f270a01..7b3c58b799a 100644 --- a/examples/rl/model_configs/nemotron6_3b_moe.sh +++ b/examples/rl/model_configs/nemotron6_3b_moe.sh @@ -65,7 +65,6 @@ MODEL_OPTIONS="\ --inference-dynamic-batching-num-cuda-graphs 2 \ --decode-only-cuda-graphs \ --cuda-graph-impl local \ - --cuda-graph-scope full \ --use-checkpoint-args \ --enable-experimental \ --cross-entropy-loss-fusion \ diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index c8f5e8e6c62..c0318c3b10a 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -45,7 +45,7 @@ from megatron.core.inference.utils import Counter, await_process_call from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import delete_cuda_graphs -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import InferenceCudaGraphScope from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from megatron.core.utils import ( deprecate_args, @@ -224,7 +224,8 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.disable_ep_consensus = inference_config.disable_ep_consensus self.ep_consensus_interval = inference_config.ep_consensus_interval self.cuda_graph_impl = model_config.cuda_graph_impl - self.cuda_graph_scope = model_config.cuda_graph_scope + self.inference_cuda_graph_scope = model_config.inference_cuda_graph_scope + self.cuda_graph_modules = model_config.cuda_graph_modules # Initialize engine. self.reset() @@ -331,17 +332,11 @@ def create_cuda_graphs(self, reset_context: bool = True): reset_context (bool): Whether to reset the context after building cuda graphs. 
""" - if self.cuda_graph_impl != "local": + if self.inference_cuda_graph_scope == InferenceCudaGraphScope.none: return - if ( - CudaGraphScope.full_iteration in self.cuda_graph_scope - and CudaGraphScope.full_iteration_inference not in self.cuda_graph_scope - ): - warnings.warn( - "\n\n*** WARNING: 'full_iteration' CUDA graph scope used during inference! " - "This will not create inference CUDA graphs. Use '--cuda-graph-scope=full_iteration_inference' instead. ***\n" - ) + if self.cuda_graph_impl != "local": + return context = self.context controller = self.controller diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 87edddea566..3e788fec0b1 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -36,7 +36,6 @@ gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region, ) -from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.moe.moe_layer import BaseMoELayer from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from megatron.core.transformer.utils import set_model_to_sequence_parallel @@ -1924,10 +1923,7 @@ def generate_all_output_tokens_static_batch( ) # Check whether CUDA graphs are enabled - enable_cuda_graph = ( - model_config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in model_config.cuda_graph_scope - ) + enable_cuda_graph = model_config.cuda_graph_impl == "local" # Pad batch tokens if necessary batch_size = len(active_requests) diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index 84b0ca2fea3..34e3f6b1ba4 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -22,7 +22,7 @@ is_vp_last_stage, ) from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.enums import AttnBackend, CudaGraphScope +from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.multi_token_prediction import tie_word_embeddings_state_dict from megatron.core.transformer.transformer_config import TransformerConfig @@ -159,8 +159,8 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: labels = torch.as_strided(labels, labels.size(), (labels.size()[1], 1)) # Use is_cg_capturable=True for full iteration CUDA graphs to avoid torch.equal checks is_cg_capturable = ( - hasattr(self.config, 'cuda_graph_scope') - and CudaGraphScope.full_iteration in self.config.cuda_graph_scope + hasattr(self.config, 'cuda_graph_impl') + and self.config.cuda_graph_impl == "full_iteration" ) if is_cg_capturable and not is_te_min_version("2.7.0"): from megatron.core.utils import get_te_version @@ -169,7 +169,7 @@ def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: raise AssertionError( f"CUDA graph compatible cross entropy requires TransformerEngine >= 2.7.0, " f"but found version {current_version}. Please upgrade TransformerEngine " - f"or set cuda_graph_scope to a value other than 'full_iteration'." + f"or set cuda_graph_impl to a value other than 'full_iteration'." 
) loss = te_parallel_cross_entropy( diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index acfa7a1e8a8..9ad35c6ffd4 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -14,7 +14,7 @@ FineGrainedActivationOffloadingInterface as off_interface, ) from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule from megatron.core.transformer.module import GraphableMegatronModule, float16_to_fp32 from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.multi_token_prediction import ( @@ -93,7 +93,7 @@ def should_free_input(name, is_moe, config, num_local_experts): # If moe_preprocess is in cuda graph scope, tokens and probs are fixed size tensors, # so they cannot be freed. "moe_dispatch": not (enable_deepep or enable_hybridep) - and (CudaGraphScope.moe_preprocess not in config.cuda_graph_scope), + and (CudaGraphModule.moe_preprocess not in config.cuda_graph_modules), } return free_input_nodes.get(name, False) @@ -388,16 +388,16 @@ def __init__(self, layer): ) else: self.shared_expert_dw_callable = None - self.cuda_graph_scope = layer.config.cuda_graph_scope + self.cuda_graph_modules = layer.config.cuda_graph_modules def backward_dw(self): """Execute weight gradients, skipping CUDA graphed components during replay.""" is_replay = hasattr(self.layer, 'cuda_graphs') and self.layer.cuda_graphs if self.shared_expert_dw_callable is not None and ( - not is_replay or CudaGraphScope.moe_router not in self.cuda_graph_scope + not is_replay or CudaGraphModule.moe_router not in self.cuda_graph_modules ): self.shared_expert_dw_callable() - if not is_replay or CudaGraphScope.attn not in self.cuda_graph_scope: + if not is_replay or CudaGraphModule.attn not in self.cuda_graph_modules: self.attn_dw_callable() if is_replay and self.graphed_backward_dw_callable is not None: self.graphed_backward_dw_callable() diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index da8a0e2fdfd..24378a00058 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -24,7 +24,7 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.quantization.utils import get_quant_config_or_none from megatron.core.tensor_parallel import gather_from_sequence_parallel_region -from megatron.core.transformer.enums import CudaGraphScope, ModelType +from megatron.core.transformer.enums import ModelType from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, mtp_on_this_rank, @@ -408,13 +408,7 @@ def _preprocess( if ( in_inference_mode - and ( - ( - self.config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope - ) - or self.config.flash_decode - ) + and (self.config.cuda_graph_impl == "local" or self.config.flash_decode) and inference_context.is_static_batching() ): current_batch_size = input_ids.shape[0] diff --git a/megatron/core/models/hybrid/hybrid_block.py b/megatron/core/models/hybrid/hybrid_block.py index 6d20bcdd6e5..93cb56f5297 100644 --- a/megatron/core/models/hybrid/hybrid_block.py +++ b/megatron/core/models/hybrid/hybrid_block.py @@ -23,7 +23,6 @@ from megatron.core.packed_seq_params import PackedSeqParams from 
megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -253,13 +252,7 @@ def forward( inference_context.seqlen_offset = inference_context.sequence_len_offset if ( - ( - ( - self.config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope - ) - or self.config.flash_decode - ) + (self.config.cuda_graph_impl == "local" or self.config.flash_decode) and inference_context and inference_context.is_static_batching() and not self.training diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 4b5858ef9da..15a7303ed2f 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -20,7 +20,7 @@ from megatron.core.quantization.utils import get_quant_config_or_none from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.enums import CudaGraphScope, ModelType +from megatron.core.transformer.enums import InferenceCudaGraphScope, ModelType from megatron.core.transformer.module import GraphableMegatronModule from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, @@ -353,7 +353,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs): kwargs.get('inference_context') is not None or kwargs.get('inference_params') is not None ) - and CudaGraphScope.full_iteration_inference in self.config.cuda_graph_scope + and self.config.inference_cuda_graph_scope == InferenceCudaGraphScope.block ): if kwargs['inference_context'].is_static_batching(): using_cuda_graph = kwargs['inference_context'].is_decode_only() @@ -373,7 +373,7 @@ def create_mcore_cudagraph_manager(self, config): """ Create the cudagraph manager for the full iteration inference scope """ - if CudaGraphScope.full_iteration_inference in config.cuda_graph_scope: + if config.inference_cuda_graph_scope == InferenceCudaGraphScope.block: from megatron.core.transformer.cuda_graphs import CudaGraphManager self.cudagraph_manager = CudaGraphManager(config) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 14fc6041574..85f9e584953 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -24,7 +24,6 @@ ProcessGroupCollection, ) from megatron.core.transformer.cuda_graphs import create_cudagraphs, set_current_microbatch -from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler from megatron.core.utils import ( drain_embedding_wgrad_compute, @@ -756,11 +755,7 @@ def forward_backward_no_pipelining( if config.timers is not None: config.timers('forward-backward').stop() - if ( - hasattr(config, 'cuda_graph_impl') - and config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in config.cuda_graph_scope - ): + if hasattr(config, 'cuda_graph_impl') and config.cuda_graph_impl == "local": create_cudagraphs() return forward_data_store @@ -1989,11 +1984,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): if config.timers is not None: 
config.timers('forward-backward').stop() - if ( - hasattr(config, 'cuda_graph_impl') - and config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in config.cuda_graph_scope - ): + if hasattr(config, 'cuda_graph_impl') and config.cuda_graph_impl == "local": create_cudagraphs() nvtx_range_pop(suffix="misc") @@ -2393,11 +2384,7 @@ def enable_grad_sync(): if config.timers is not None: config.timers('forward-backward').stop() - if ( - hasattr(config, 'cuda_graph_impl') - and config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in config.cuda_graph_scope - ): + if hasattr(config, 'cuda_graph_impl') and config.cuda_graph_impl == "local": create_cudagraphs() return forward_data_store diff --git a/megatron/core/safe_globals.py b/megatron/core/safe_globals.py index 9241405876a..92151968051 100755 --- a/megatron/core/safe_globals.py +++ b/megatron/core/safe_globals.py @@ -16,7 +16,12 @@ from megatron.core.enums import ModelType from megatron.core.optimizer import OptimizerConfig from megatron.core.rerun_state_machine import RerunDiagnostic, RerunMode, RerunState -from megatron.core.transformer.enums import AttnBackend, CudaGraphScope +from megatron.core.transformer.enums import ( + AttnBackend, + CudaGraphModule, + CudaGraphScope, + InferenceCudaGraphScope, +) SAFE_GLOBALS = [ SimpleNamespace, @@ -27,7 +32,9 @@ UInt32DType, Namespace, AttnBackend, + CudaGraphModule, CudaGraphScope, + InferenceCudaGraphScope, ModelType, OptimizerConfig, RerunDiagnostic, diff --git a/megatron/core/ssm/mamba_layer.py b/megatron/core/ssm/mamba_layer.py index 17903cebf3b..98a73e5ddfc 100644 --- a/megatron/core/ssm/mamba_layer.py +++ b/megatron/core/ssm/mamba_layer.py @@ -16,7 +16,7 @@ from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule, InferenceCudaGraphScope from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import GraphableMegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -94,9 +94,14 @@ def __init__( def create_mcore_cudagraph_manager(self, config): """Register the mamba layer for cudagraphs.""" + assert self.config.cuda_graph_impl == "local" + from megatron.core.transformer.cuda_graphs import CudaGraphManager - if not self.config.cuda_graph_scope or CudaGraphScope.mamba in self.config.cuda_graph_scope: + if ( + not self.config.cuda_graph_modules + and self.config.inference_cuda_graph_scope != InferenceCudaGraphScope.block + ) or CudaGraphModule.mamba in self.config.cuda_graph_modules: self.cudagraph_manager = CudaGraphManager(config) def mamba_state_shapes_per_request(self) -> Tuple[Tuple[int], Tuple[int]]: @@ -202,7 +207,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs): hasattr(self, 'cudagraph_manager') and kwargs.get('attention_mask') is None and kwargs.get('inference_context') is not None - and not self.config.cuda_graph_scope # empty-list = per-layer CUDA graphs + and not self.config.cuda_graph_modules # empty-list = per-layer CUDA graphs ): context = kwargs['inference_context'] using_cuda_graph = (context.is_static_batching() and context.is_decode_only()) or ( diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index f89259be442..553dff60106 100644 --- 
a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -52,7 +52,7 @@ from ..models.common.embeddings.yarn_rotary_pos_embedding import ( _yarn_get_concentration_factor_from_config, ) -from .enums import AttnMaskType, CudaGraphScope +from .enums import AttnMaskType from .transformer_config import TransformerConfig try: @@ -1139,7 +1139,6 @@ def forward( if ( in_decode_mode and self.config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope and inference_context.is_static_batching() ): raise ValueError(f"CUDA graphs must use flash decode with static batching!") diff --git a/megatron/core/transformer/cuda_graph_config.py b/megatron/core/transformer/cuda_graph_config.py new file mode 100644 index 00000000000..075e09ea94f --- /dev/null +++ b/megatron/core/transformer/cuda_graph_config.py @@ -0,0 +1,136 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +from typing import List, Optional, Set, Tuple, Union + +from megatron.core.transformer.enums import CudaGraphModule, InferenceCudaGraphScope + +# Maps deprecated scope strings to the (attr_name, new_value) they should set. +# new_value is the actual value to assign: a str for cuda_graph_impl (which is a +# Literal string type) or an InferenceCudaGraphScope enum for inference_cuda_graph_scope. +CUDA_GRAPH_MODULES_DEPRECATIONS = { + 'full_iteration': ('cuda_graph_impl', 'full_iteration'), + 'full_iteration_inference': ('inference_cuda_graph_scope', InferenceCudaGraphScope.block), +} + +# Canonical mapping from cuda_graph_impl to the set of allowed inference granularities. +# Shared by transformer_config.__post_init__ and validate_args to avoid duplication. +ALLOWED_INFERENCE_SCOPES: dict[str, Set[InferenceCudaGraphScope]] = { + "none": {InferenceCudaGraphScope.none}, + "local": {InferenceCudaGraphScope.layer, InferenceCudaGraphScope.block}, + "transformer_engine": {InferenceCudaGraphScope.none}, + "full_iteration": {InferenceCudaGraphScope.none}, +} + + +def normalize_cuda_graph_modules( + scopes: Optional[Union[str, CudaGraphModule, List[Union[str, CudaGraphModule]]]] +) -> Tuple[List[CudaGraphModule], List[Tuple[str, str, object]], bool]: + """Normalize mixed CUDA graph scope inputs into enum values plus deprecation metadata.""" + + if scopes is None: + raw_scopes = [] + elif isinstance(scopes, list): + raw_scopes = list(scopes) + elif isinstance(scopes, str): + raw_scopes = scopes.split(',') if scopes else [] + else: + raw_scopes = [scopes] + + if "full" in raw_scopes: + assert raw_scopes == ["full"], "full scope cannot be used with other scopes." + return [], [], True + + normalized_scopes: List[CudaGraphModule] = [] + deprecated_scopes: List[Tuple[str, str, object]] = [] + for scope in raw_scopes: + if isinstance(scope, CudaGraphModule): + normalized_scopes.append(scope) + else: + assert isinstance(scope, str), ( + "cuda_graph_modules values must be strings or CudaGraphModule enums, " + f"got {scope!r}." 
+ ) + if scope in CUDA_GRAPH_MODULES_DEPRECATIONS: + attr, value = CUDA_GRAPH_MODULES_DEPRECATIONS[scope] + deprecated_scopes.append((scope, attr, value)) + else: + normalized_scopes.append(CudaGraphModule[scope]) + + return normalized_scopes, deprecated_scopes, False + + +def normalize_inference_cuda_graph_scope( + scope: Optional[Union[str, InferenceCudaGraphScope]], cuda_graph_impl: str +) -> InferenceCudaGraphScope: + """Normalize inference CUDA graph scope and apply the impl-derived default.""" + + if scope is None: + if cuda_graph_impl == "local": + return InferenceCudaGraphScope.layer + return InferenceCudaGraphScope.none + + if isinstance(scope, InferenceCudaGraphScope): + return scope + + assert isinstance(scope, str), ( + "inference_cuda_graph_scope must be a string or " + f"InferenceCudaGraphScope enum, got {scope!r}." + ) + return InferenceCudaGraphScope[scope] + + +def validate_deprecated_cuda_graph_modules_migration_inputs( + deprecated_scopes: List[Tuple[str, str, object]], + cuda_graph_impl: str, + inference_cuda_graph_scope: Optional[Union[str, InferenceCudaGraphScope]], +) -> None: + """Reject ambiguous mixed old/new CUDA graph inputs before applying migration. + + Deprecated scope strings are still accepted for compatibility, but only when they are not + combined with conflicting new-style fields. + """ + + deprecated_scope_names = [scope for scope, _, _ in deprecated_scopes] + if not deprecated_scope_names: + return + + if len(set(deprecated_scope_names)) > 1: + raise AssertionError( + "cuda_graph_modules cannot contain multiple deprecated values at the same time: " + f"{deprecated_scope_names!r}." + ) + + scope = deprecated_scope_names[0] + if isinstance(inference_cuda_graph_scope, str): + inference_cuda_graph_scope = InferenceCudaGraphScope[inference_cuda_graph_scope] + + if scope == "full_iteration": + assert cuda_graph_impl in ("none", "local", "full_iteration"), ( + "cuda_graph_modules='full_iteration' cannot be combined with " + f"cuda_graph_impl={cuda_graph_impl!r}." + ) + assert inference_cuda_graph_scope in (None, InferenceCudaGraphScope.none), ( + "cuda_graph_modules='full_iteration' cannot be combined with " + "inference_cuda_graph_scope=" + f"{getattr(inference_cuda_graph_scope, 'name', inference_cuda_graph_scope)!r}." + ) + elif scope == "full_iteration_inference": + assert cuda_graph_impl in ("none", "local"), ( + "cuda_graph_modules='full_iteration_inference' cannot be combined with " + f"cuda_graph_impl={cuda_graph_impl!r}." + ) + assert inference_cuda_graph_scope in (None, InferenceCudaGraphScope.block), ( + "cuda_graph_modules='full_iteration_inference' cannot be combined with " + "inference_cuda_graph_scope=" + f"{getattr(inference_cuda_graph_scope, 'name', inference_cuda_graph_scope)!r}." 
+ ) + + +def get_deprecated_cuda_graph_modules_migration( + scope: str, attr: str, value: object, cuda_graph_impl: str +) -> Optional[Tuple[str, object]]: + """Return the effective new-style migration for a deprecated cuda_graph_modules value.""" + + if scope == "full_iteration_inference" and cuda_graph_impl == "none": + return None + return attr, value diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 9cd8cd2ffb6..4b1dc3260ab 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -28,7 +28,7 @@ get_cuda_rng_tracker, is_checkpointing, ) -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import ( @@ -1696,8 +1696,8 @@ def _layer_is_graphable(layer, config): if not isinstance(layer, GraphableMegatronModule): return False - # If cuda_graph_scope is not set, every layer is graphed. - if not config.cuda_graph_scope: + # If cuda_graph_modules is not set, every layer is graphed. + if not config.cuda_graph_modules: return True # import modules here to avoid a circular import @@ -1707,24 +1707,24 @@ def _layer_is_graphable(layer, config): from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.transformer_layer import TransformerLayer - if isinstance(layer, MambaLayer) and CudaGraphScope.mamba in config.cuda_graph_scope: + if isinstance(layer, MambaLayer) and CudaGraphModule.mamba in config.cuda_graph_modules: # mamba layer. return True if isinstance(layer, TransformerLayer): - if CudaGraphScope.attn in config.cuda_graph_scope and not ( + if CudaGraphModule.attn in config.cuda_graph_modules and not ( isinstance(layer.self_attention, IdentityOp) and isinstance(layer.cross_attention, IdentityOp) ): # attn layer. return True if ( - CudaGraphScope.moe in config.cuda_graph_scope - or CudaGraphScope.moe_router in config.cuda_graph_scope - or CudaGraphScope.moe_preprocess in config.cuda_graph_scope + CudaGraphModule.moe in config.cuda_graph_modules + or CudaGraphModule.moe_router in config.cuda_graph_modules + or CudaGraphModule.moe_preprocess in config.cuda_graph_modules ) and isinstance(layer.mlp, MoELayer): # moe layer. return True - if CudaGraphScope.mlp in config.cuda_graph_scope and isinstance(layer.mlp, MLP): + if CudaGraphModule.mlp in config.cuda_graph_modules and isinstance(layer.mlp, MLP): # mlp layer. return True return False @@ -1753,11 +1753,6 @@ def __init__( "Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using " "CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True." ) - assert CudaGraphScope.full_iteration not in config.cuda_graph_scope, ( - "full_iteration cuda graph is not supported for cuda_graph_impl=transformer_engine. " - "Please use cuda_graph_impl=local instead." 
- ) - self.model = model self.config = config self.seq_length = seq_length @@ -1978,8 +1973,8 @@ def get_rotary_pos_emb(transformer_module, transformer_input): isinstance(layer, TransformerLayer) and not isinstance(layer.self_attention, IdentityOp) and ( - not self.config.cuda_graph_scope - or CudaGraphScope.attn in self.config.cuda_graph_scope + not self.config.cuda_graph_modules + or CudaGraphModule.attn in self.config.cuda_graph_modules ) ) @@ -2187,8 +2182,8 @@ def _get_cuda_graph_input_data(self): ) chunk_id_list = None if self.config.overlap_moe_expert_parallel_comm: - wgrad_in_graph_scope = CudaGraphScope.attn in self.config.cuda_graph_scope or ( - CudaGraphScope.moe_router in self.config.cuda_graph_scope + wgrad_in_graph_scope = CudaGraphModule.attn in self.config.cuda_graph_modules or ( + CudaGraphModule.moe_router in self.config.cuda_graph_modules and self.config.moe_shared_expert_intermediate_size is not None and not self.config.moe_shared_expert_overlap ) diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py index 1bf16095908..eb83c9c4ca0 100644 --- a/megatron/core/transformer/enums.py +++ b/megatron/core/transformer/enums.py @@ -58,14 +58,55 @@ class AttnBackend(enum.Enum): auto = 5 +class CudaGraphModule(enum.Enum): + """Named capture regions for per-layer CUDA graphs. + + Whole-layer capture is represented outside this enum by an empty scope. Current per-layer + implementations that consume these values are `local` and `transformer_engine`. + """ + + attn = 1 # Captures attention layers + mlp = 2 # Captures MLP layers (dense layers only) + moe = 3 # Captures MoE layers (drop-and-pad MoE layers only) + moe_router = 4 # Captures MoE router part + moe_preprocess = 5 # Captures MoE preprocessing part (requires moe_router) + mamba = 6 # Captures Mamba layers + + +# Deprecated: use CudaGraphModule instead. Retained only for checkpoint backward compat. class CudaGraphScope(enum.Enum): - """Cuda Graph Scope - defines which parts of the model to capture.""" - - full_iteration = 1 # Captures the entire training iteration - attn = 2 # Captures attention layers - mlp = 3 # Captures MLP layers (dense layers only) - moe = 4 # Captures MoE layers (drop-and-pad MoE layers only) - moe_router = 5 # Captures MoE router part - moe_preprocess = 6 # Captures MoE preprocessing part (requires moe_router) - mamba = 7 # Captures Mamba layers - full_iteration_inference = 8 # Captures the entire inference iteration + """Deprecated predecessor of CudaGraphModule. + + Preserved as a standalone class (not an alias) so that pre-refactor checkpoints that + stored CudaGraphScope enum instances can be deserialized correctly. The original ordinals + differ from CudaGraphModule (full_iteration=1, attn=2, …), so a simple alias would + silently reconstruct enum members with the wrong identity. + + Do NOT use in new code. Migration guide: + - full_iteration → cuda_graph_impl="full_iteration" + - full_iteration_inference → inference_cuda_graph_scope=InferenceCudaGraphScope.block + - all other members → equivalent CudaGraphModule member + """ + + full_iteration = 1 + attn = 2 + mlp = 3 + moe = 4 + moe_router = 5 + moe_preprocess = 6 + mamba = 7 + full_iteration_inference = 8 + + +class InferenceCudaGraphScope(enum.Enum): + """Inference CUDA graph scope. + + This controls the ownership boundary for inference CUDA graphs: + - none: inference runs in eager mode (no CUDA graphs). + - layer: graphs are owned at the module/layer boundary, e.g. TransformerLayer or MambaLayer. 
+ - block: graphs are owned by the enclosing block, e.g. TransformerBlock or HybridBlock. + """ + + none = 1 + layer = 2 + block = 3 diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 6539ee36105..c30c107e791 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -10,7 +10,6 @@ from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( ensure_metadata_has_dp_cp_group, @@ -169,10 +168,7 @@ def __init__(self, config: TransformerConfig, vp_stage: Optional[int] = None): assert isinstance(config, TransformerConfig), "config must be a TransformerConfig" # Enable cuda graphs. - if ( - config.cuda_graph_impl == "local" - and CudaGraphScope.full_iteration not in config.cuda_graph_scope - ): + if config.cuda_graph_impl == "local": if hasattr(self, "create_mcore_cudagraph_manager"): self.create_mcore_cudagraph_manager(config) else: diff --git a/megatron/core/transformer/moe/README.md b/megatron/core/transformer/moe/README.md index 213941b32cf..6a268ce43e9 100644 --- a/megatron/core/transformer/moe/README.md +++ b/megatron/core/transformer/moe/README.md @@ -259,7 +259,7 @@ After establishing a working parallel configuration, profile your training to id | Optimization | Config | |--------------|--------| | Disable Python GC | `--manual-gc --manual-gc-interval 100` | -| Enable CUDA Graphs | `--cuda-graph-impl transformer_engine --cuda-graph-scope attn moe_router moe_preprocess` | +| Enable CUDA Graphs | `--cuda-graph-impl transformer_engine --cuda-graph-modules attn moe_router moe_preprocess` | | Reduce kernel launches | Decrease TP size or increase micro-batch size | #### Computation Bottleneck @@ -499,14 +499,19 @@ FP8 training provides benefits across all three performance walls: ### CUDA Graph -CUDA Graph functionality can be enabled through the `--cuda-graph-impl` option. There are two implementations: +CUDA Graph functionality can be enabled through the `--cuda-graph-impl` option. There are three implementations: -1. `--cuda-graph-impl=local`: Captures cuda graphs using the MCore-internal cuda graph manager. -2. `--cuda-graph-impl=transformer_engine`: Captures cuda graphs using the TE `make_graphed_callables()` interface. +1. `--cuda-graph-impl=local`: Captures per-layer cuda graphs using the MCore-internal cuda graph manager. +2. `--cuda-graph-impl=transformer_engine`: Captures per-layer cuda graphs using the TE `make_graphed_callables()` interface. +3. `--cuda-graph-impl=full_iteration`: Captures the whole training/evaluation forward-backward iteration as a single cuda graph. + +For inference, CUDA graph scope is controlled separately with +`--inference-cuda-graph-scope=layer|block` together with +`--cuda-graph-impl=local`. To use `--cuda-graph-impl=transformer_engine`, the user should call related methods `TECudaGraphHelper.create_cudagraphs()` and `TECudaGraphHelper.cuda_graph_set_manual_hooks()` in the training script. Please refer to the usage in `megatron/training/training.py`. -For MoE models, certain configurations may prevent CUDA Graph capture of MoE layers. Specifically, when `--moe-expert-capacity-factor` and `--moe-pad-expert-input-to-capacity` are not set, the resulting dynamic shapes make MoE layers uncapturable. 
In such cases, you can still leverage CUDA Graphs for the attention layers (operations in `TransformerLayer._forward_attention()`) by setting `--cuda-graph-scope=attn`, while leaving the MoE layers (operations in `TransformerLayer._forward_mlp()`) unmodified. See the argument description for more usage of `--cuda-graph-scope`. +For MoE models, certain configurations may prevent CUDA Graph capture of MoE layers. Specifically, when `--moe-expert-capacity-factor` and `--moe-pad-expert-input-to-capacity` are not set, the resulting dynamic shapes make MoE layers uncapturable. In such cases, you can still leverage CUDA Graphs for the attention layers (operations in `TransformerLayer._forward_attention()`) by setting `--cuda-graph-modules=attn`, while leaving the MoE layers (operations in `TransformerLayer._forward_mlp()`) unmodified. See the argument description for more usage of `--cuda-graph-modules`. ## MoE Arguments Reference ### Core Arguments diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 2ddc17a567a..e240d74d846 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -637,7 +637,7 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): # This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator. # It means we should early-return from the MoE layer forward pass. # This happens when we are partially capturing the CUDA graph of the MoE layer, - # like cuda_graph_scope=["moe_router", "moe_preprocess"]. + # like cuda_graph_modules=["moe_router", "moe_preprocess"]. # We need to return the intermediate tensors as CUDA graph outputs. return e.get_early_return_outputs(hidden_states, shared_expert_output) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index f258f3474ae..8552ecaad82 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -19,7 +19,7 @@ ) from megatron.core.tensor_parallel.mappings import reduce_from_tensor_model_parallel_region from megatron.core.transformer.cuda_graphs import is_graph_capturing -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule from megatron.core.transformer.moe.router_replay import RouterReplay from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import internal_api, is_te_min_version @@ -1605,13 +1605,13 @@ def maybe_raise_signal(moe_layer, **kwargs): ): if ( step_condition == "route" - and CudaGraphScope.moe_router in moe_layer.config.cuda_graph_scope - and CudaGraphScope.moe_preprocess not in moe_layer.config.cuda_graph_scope + and CudaGraphModule.moe_router in moe_layer.config.cuda_graph_modules + and CudaGraphModule.moe_preprocess not in moe_layer.config.cuda_graph_modules ): raise MoECudaGraphPartialCaptureSignal(moe_layer, "route", **kwargs) elif ( step_condition == "preprocess" - and CudaGraphScope.moe_preprocess in moe_layer.config.cuda_graph_scope + and CudaGraphModule.moe_preprocess in moe_layer.config.cuda_graph_modules ): raise MoECudaGraphPartialCaptureSignal(moe_layer, "preprocess", **kwargs) diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 717e285a249..0e379b10bb8 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -16,7 +16,7 @@ 
gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region, ) -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule from megatron.core.transformer.moe.fused_a2a import ( fused_combine, fused_dispatch, @@ -80,7 +80,7 @@ def __init__( self.ep_size = utils.get_pg_size(self.ep_group) # Attributes that need to be captured in cudagraph. These attributes are returned - # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. + # as cudagraph outputs when the cuda_graph_modules contains moe_preprocess. self.cudagraph_attrs = [] self.valid_cudagraph_attrs = None @@ -243,7 +243,7 @@ def __init__( self.global_local_map = None # Attributes that need to be captured in cudagraph. These attributes are returned - # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. + # as cudagraph outputs when the cuda_graph_modules contains moe_preprocess. self.cudagraph_attrs = ['routing_map'] def dispatch_preprocess( @@ -442,15 +442,15 @@ def __init__( } self.cuda_dtoh_point = "before_permutation_1" if config.cuda_graph_impl != "none" and ( - CudaGraphScope.moe_preprocess in config.cuda_graph_scope - or not self.config.cuda_graph_scope + CudaGraphModule.moe_preprocess in config.cuda_graph_modules + or not self.config.cuda_graph_modules ): self.cuda_dtoh_point = "before_ep_alltoall" if MoEAlltoAllTokenDispatcher.cuda_dtoh_stream is None: MoEAlltoAllTokenDispatcher.cuda_dtoh_stream = torch.cuda.Stream() # Attributes that need to be captured in cudagraph. These attributes are returned - # as cudagraph outputs when the cuda_graph_scope contains moe_preprocess. + # as cudagraph outputs when the cuda_graph_modules contains moe_preprocess. self.cudagraph_attrs = [ 'tokens_per_expert', 'input_splits', diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 8bea3b8c94e..ec100448514 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -19,7 +19,7 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.enums import CudaGraphScope, LayerType +from megatron.core.transformer.enums import InferenceCudaGraphScope, LayerType from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.torch_norm import LayerNormBuilder @@ -621,7 +621,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs): kwargs.get('inference_context') is not None or kwargs.get('inference_params') is not None ) - and CudaGraphScope.full_iteration_inference in self.config.cuda_graph_scope + and self.config.inference_cuda_graph_scope == InferenceCudaGraphScope.block ): if kwargs['inference_context'].is_static_batching(): using_cuda_graph = kwargs['inference_context'].is_decode_only() diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index bb044787b9c..5351e681f88 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -12,7 +12,19 @@ from megatron.core.enums import Fp4Recipe, Fp8Recipe from megatron.core.inference.moe import InferenceGroupedGemmBackend from 
megatron.core.quantization.quant_config import RecipeConfig -from megatron.core.transformer.enums import AttnBackend, CudaGraphScope +from megatron.core.transformer.cuda_graph_config import ( + ALLOWED_INFERENCE_SCOPES, + get_deprecated_cuda_graph_modules_migration, + normalize_cuda_graph_modules, + normalize_inference_cuda_graph_scope, + validate_deprecated_cuda_graph_modules_migration_inputs, +) +from megatron.core.transformer.enums import ( + AttnBackend, + CudaGraphModule, + CudaGraphScope, + InferenceCudaGraphScope, +) from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout from .._rank_utils import log_single_rank @@ -846,7 +858,7 @@ class TransformerConfig(ModelParallelConfig): enable_cuda_graph: bool = False """DEPRECATED and replaced by cuda_graph_impl. When set to true, either partial CUDA graph (1/many CUDA graph per layer) or full iteration - CUDA graph (1 CUDA graph for whole iteration excluding optimizer) is enabled. --cuda-graph-scope + CUDA graph (1 CUDA graph for whole iteration excluding optimizer) is enabled. cuda_graph_modules determines the scope of graph capture.""" cuda_graph_use_single_mempool: bool = False @@ -868,22 +880,69 @@ class TransformerConfig(ModelParallelConfig): """DEPRECATED and replaced by cuda_graph_impl. When set to true, TransformerLayer layers are swapped with user provided CUDA graphs.""" - cuda_graph_impl: Literal['none', 'local', 'transformer_engine'] = "none" + cuda_graph_impl: Literal['none', 'local', 'transformer_engine', 'full_iteration'] = "none" """Determines the CUDA graph capture implementation. "none": no CUDA graph. - "local": capture the CUDA graph using MCore local implementation. Either partial CUDA graph - (1/many CUDA graph per layer) or full iteration CUDA graph (1 CUDA graph for whole iteration - excluding optimizer) is enabled. - "transformer_engine": capture the CUDA graph using TE make_graphed_callables().""" - - cuda_graph_scope: Union[str, CudaGraphScope, List[str], List[CudaGraphScope]] = "full" - """Determines the CUDA graphs capturing scope. - When cuda_graph_impl is set to "transformer_engine", valid values are "attn", "mlp", "moe", - "moe_router", "moe_preprocess", "mamba". "full" or an empty list means the full layer. "full" - is actually deprecated, but for backward compatibility, we still use "full" as the default - value. It will be transformed to an empty list in __post_init__. - When cuda_graph_impl is set to "local", "full_iteration" can be specified as cuda_graph_scope - to enable whole iteration CUDA graph. All other values enable layerwise CUDA graph.""" + "local": MCore CUDA graph implementation. During training, graphable modules own per-layer + CUDA graphs controlled by cuda_graph_modules. During inference, graph ownership is controlled + separately by inference_cuda_graph_scope. + "transformer_engine": Transformer Engine CUDA graph implementation. During training, TE + make_graphed_callables() creates per-layer CUDA graphs controlled by cuda_graph_modules. + Inference CUDA graphs are not supported; inference_cuda_graph_scope must be "none". + "full_iteration": full-iteration CUDA graph implementation for the training iteration + (1 CUDA graph for the whole forward-backward path excluding the optimizer step). Inference + CUDA graphs are not supported; inference_cuda_graph_scope must be "none". 
+ cuda_graph_modules has no effect when cuda_graph_impl="none" and must be empty when + cuda_graph_impl="full_iteration".""" + + cuda_graph_modules: Union[str, CudaGraphModule, List[str], List[CudaGraphModule]] = "full" + """Selects training capture coverage within per-layer CUDA graphs (local and + transformer_engine implementations). + Valid values are "attn", "mlp", "moe", "moe_router", "moe_preprocess", and "mamba": + "attn": captures operations in TransformerLayer._forward_attention(). + "mlp": captures operations in TransformerLayer._forward_mlp() for a dense layer. + "moe": captures operations in TransformerLayer._forward_mlp() for a MoE layer. + "moe_router": captures operations in TransformerLayer._forward_mlp() up to MoELayer.router(), + including the shared experts if they are not overlapped with EP comm. + "moe_preprocess": captures operations in MoELayer.preprocess(). Must be used together with + "moe_router". + "mamba": captures the mamba layer. + An empty list means capturing the whole Transformer layer. + This field is meaningless when cuda_graph_impl="full_iteration" and must be empty. + Backward compatibility: "full" is deprecated but kept for backward compatibility; it is + transformed to an empty list in __post_init__. The deprecated values "full_iteration" and + "full_iteration_inference" are also accepted and migrated to the new API in __post_init__.""" + + inference_cuda_graph_scope: Optional[InferenceCudaGraphScope] = field( + default=None, + metadata={ + "argparse_meta": { + "type": str, + "choices": [scope.name for scope in InferenceCudaGraphScope], + } + }, + ) + """Controls the CUDA graph scope during inference. + When unset, the effective default is derived from cuda_graph_impl: + "local" -> "layer", all other impls -> "none". + "none": inference runs in eager mode (no CUDA graphs). + "layer": inference graphs are owned at the module/layer boundary, e.g. TransformerLayer or + MambaLayer. + "block": inference graphs are owned by the enclosing block, e.g. TransformerBlock or + HybridBlock. + Currently supported combinations are: + cuda_graph_impl="local" -> "layer" or "block"; + all other cuda_graph_impl values -> "none".""" + + cuda_graph_scope: Optional[ + Union[ + str, CudaGraphModule, CudaGraphScope, List[Union[str, CudaGraphModule, CudaGraphScope]] + ] + ] = None + """Deprecated: renamed to cuda_graph_modules. Accepted for backward compatibility and + migrated to cuda_graph_modules in __post_init__. Will be removed in a future release. + CudaGraphScope instances deserialized from pre-refactor checkpoints are converted to their + string names before normalization so existing CUDA_GRAPH_MODULES_DEPRECATIONS handles them.""" #################### # miscellaneous @@ -1966,103 +2025,149 @@ def __post_init__(self): ) self.cuda_graph_impl = "transformer_engine" - if self.cuda_graph_scope is None: - self.cuda_graph_scope = [] - elif not isinstance(self.cuda_graph_scope, list): - if isinstance(self.cuda_graph_scope, CudaGraphScope): - self.cuda_graph_scope = [self.cuda_graph_scope] + if self.cuda_graph_scope is not None: + assert self.cuda_graph_modules in ( + "full", + None, + [], + ), "cuda_graph_scope and cuda_graph_modules cannot be set together." + warnings.warn( + "cuda_graph_scope is deprecated, use cuda_graph_modules instead.", + DeprecationWarning, + stacklevel=2, + ) + + # CudaGraphScope is preserved as a standalone class (not an alias of CudaGraphModule) + # so pre-refactor checkpoint values deserialize with the correct member identity. 
+ # Convert to string names here so normalize_cuda_graph_modules handles them uniformly. + def _scope_to_str(s): + return s.name if isinstance(s, CudaGraphScope) else s + + scope = self.cuda_graph_scope + if isinstance(scope, list): + self.cuda_graph_modules = [_scope_to_str(s) for s in scope] else: - assert isinstance(self.cuda_graph_scope, str), ( - "cuda_graph_scope must be a string that can be converted to a list of " - f"CudaGraphScope, got {self.cuda_graph_scope}." - ) - self.cuda_graph_scope = self.cuda_graph_scope.split(',') - if all(isinstance(scope, str) for scope in self.cuda_graph_scope): - # Backward compatibility for "full" scope. Now we use an empty list instead. - if "full" in self.cuda_graph_scope: - assert self.cuda_graph_scope == [ - "full" - ], "full scope cannot be used with other scopes." + self.cuda_graph_modules = _scope_to_str(scope) + self.cuda_graph_scope = None + + normalized_scopes, deprecated_scopes, used_full_scope = normalize_cuda_graph_modules( + self.cuda_graph_modules + ) + validate_deprecated_cuda_graph_modules_migration_inputs( + deprecated_scopes, self.cuda_graph_impl, self.inference_cuda_graph_scope + ) + if used_full_scope: + warnings.warn( + "full scope is deprecated. " + "Use empty cuda_graph_modules to capture the whole layer." + ) + for scope, attr, value in deprecated_scopes: + migration = get_deprecated_cuda_graph_modules_migration( + scope, attr, value, self.cuda_graph_impl + ) + if migration is None: warnings.warn( - "full scope is deprecated. " - "Use empty cuda_graph_scope to capture the whole layer." - ) - self.cuda_graph_scope = [] - else: - self.cuda_graph_scope = [CudaGraphScope[scope] for scope in self.cuda_graph_scope] + f"cuda_graph_modules '{scope}' is deprecated and has no effect when " + "cuda_graph_impl='none'. Use cuda_graph_impl='local' with " + "inference_cuda_graph_scope='block' to enable inference CUDA graphs.", + DeprecationWarning, + stacklevel=2, + ) + continue + migration_attr, migration_value = migration + warnings.warn( + f"cuda_graph_modules '{scope}' is deprecated. " + f"Use {migration_attr}={migration_value!r} instead.", + DeprecationWarning, + stacklevel=2, + ) + setattr(self, migration_attr, migration_value) + self.cuda_graph_modules = normalized_scopes assert all( - isinstance(scope, CudaGraphScope) for scope in self.cuda_graph_scope - ), f"cuda_graph_scope must be a list of CudaGraphScope, got {self.cuda_graph_scope}." + isinstance(scope, CudaGraphModule) for scope in self.cuda_graph_modules + ), f"cuda_graph_modules must be a list of CudaGraphModule, got {self.cuda_graph_modules}." + + assert self.cuda_graph_impl in [ + "none", + "transformer_engine", + "local", + "full_iteration", + ], f"Invalid cuda graph implementation: {self.cuda_graph_impl}" + + self.inference_cuda_graph_scope = normalize_inference_cuda_graph_scope( + self.inference_cuda_graph_scope, self.cuda_graph_impl + ) + + assert self.inference_cuda_graph_scope in ALLOWED_INFERENCE_SCOPES[self.cuda_graph_impl], ( + "Invalid inference CUDA graph scope " + f"{self.inference_cuda_graph_scope.name!r} for cuda_graph_impl=" + f"{self.cuda_graph_impl!r}." + ) + assert not ( + self.cuda_graph_impl == "full_iteration" and self.cuda_graph_modules + ), 'cuda_graph_modules must be empty when cuda_graph_impl="full_iteration".' 
if self.cuda_graph_impl != "none": - assert self.cuda_graph_impl in [ - "transformer_engine", - "local", - ], f"Invalid cuda graph implementation: {self.cuda_graph_impl}" - if self.cpu_offloading and self.cuda_graph_scope != [CudaGraphScope.full_iteration]: + if self.cpu_offloading and self.cuda_graph_impl != "full_iteration": raise ValueError("CUDA graphs not supported with CPU offloading.") - if self.cuda_graph_impl == "local": - # local impl doesn't currently distinguish between moe_preproocess or moe_router - # so just set both if either is specified. - if ( - CudaGraphScope.moe_router in self.cuda_graph_scope - or CudaGraphScope.moe_preprocess in self.cuda_graph_scope - ): - if CudaGraphScope.moe_router not in self.cuda_graph_scope: - self.cuda_graph_scope.append(CudaGraphScope.moe_router) - if CudaGraphScope.moe_preprocess not in self.cuda_graph_scope: - self.cuda_graph_scope.append(CudaGraphScope.moe_preprocess) + # Check cuda graph scopes for per-layer implementations. + if self.cuda_graph_impl in ("local", "transformer_engine"): + if self.cuda_graph_impl == "local": + # local impl doesn't currently distinguish between moe_preprocess or moe_router + # so just set both if either is specified. + if ( + CudaGraphModule.moe_router in self.cuda_graph_modules + or CudaGraphModule.moe_preprocess in self.cuda_graph_modules + ): + if CudaGraphModule.moe_router not in self.cuda_graph_modules: + self.cuda_graph_modules.append(CudaGraphModule.moe_router) + if CudaGraphModule.moe_preprocess not in self.cuda_graph_modules: + self.cuda_graph_modules.append(CudaGraphModule.moe_preprocess) - # Check cuda graph scopes - if self.cuda_graph_impl == "transformer_engine": - assert CudaGraphScope.full_iteration not in self.cuda_graph_scope, ( - "To use full iteration cuda graph, please use " - "cuda_graph_impl=local instead of cuda_graph_impl=transformer_engine." - ) - assert ( - CudaGraphScope.moe not in self.cuda_graph_scope - or CudaGraphScope.moe_router not in self.cuda_graph_scope - ), 'cuda_graph_scope must not contain both moe and moe_router.' - if CudaGraphScope.moe_preprocess in self.cuda_graph_scope: - assert ( - CudaGraphScope.moe_router in self.cuda_graph_scope - ), 'moe_preprocess cuda graph is only supported with moe_router cuda graph.' - if self.num_moe_experts is None or self.num_moe_experts <= 1: assert ( - CudaGraphScope.moe not in self.cuda_graph_scope - and CudaGraphScope.moe_router not in self.cuda_graph_scope - ), 'moe cuda graph is only supported for MoE.' - else: - if self.moe_layer_freq == 1 or ( - isinstance(self.moe_layer_freq, list) and 0 not in self.moe_layer_freq - ): - assert CudaGraphScope.mlp not in self.cuda_graph_scope, ( - 'mlp cuda graph is only supported for dense layers, ' - 'but not found in the model.' - ) - if ( - self.moe_expert_capacity_factor is None - or not self.moe_pad_expert_input_to_capacity - ): + CudaGraphModule.moe not in self.cuda_graph_modules + or CudaGraphModule.moe_router not in self.cuda_graph_modules + ), 'cuda_graph_modules must not contain both moe and moe_router.' + if CudaGraphModule.moe_preprocess in self.cuda_graph_modules: + assert ( + CudaGraphModule.moe_router in self.cuda_graph_modules + ), 'moe_preprocess cuda graph is only supported with moe_router cuda graph.' + if self.num_moe_experts is None or self.num_moe_experts <= 1: assert ( - CudaGraphScope.moe not in self.cuda_graph_scope - ), 'moe cuda graph is only supported with drop-padding MoE.' 
- if self.moe_token_dispatcher_type == 'alltoall' and (
- self.moe_expert_capacity_factor is not None
- or self.moe_router_padding_for_fp8
+ CudaGraphModule.moe not in self.cuda_graph_modules
+ and CudaGraphModule.moe_router not in self.cuda_graph_modules
+ ), 'moe cuda graph is only supported for MoE.'
+ else:
+ if self.moe_layer_freq == 1 or (
+ isinstance(self.moe_layer_freq, list) and 0 not in self.moe_layer_freq
 ):
- assert CudaGraphScope.moe_preprocess not in self.cuda_graph_scope, (
- 'moe_preprocess cuda graph is not supported when there are '
- 'DtoH copies and synchronizations in the preprocess step.'
+ assert CudaGraphModule.mlp not in self.cuda_graph_modules, (
+ 'mlp cuda graph is only supported for dense layers, '
+ 'but not found in the model.'
 )
+ if (
+ self.moe_expert_capacity_factor is None
+ or not self.moe_pad_expert_input_to_capacity
+ ):
+ assert (
+ CudaGraphModule.moe not in self.cuda_graph_modules
+ ), 'moe cuda graph is only supported with drop-padding MoE.'
+ if self.moe_token_dispatcher_type == 'alltoall' and (
+ self.moe_expert_capacity_factor is not None
+ or self.moe_router_padding_for_fp8
+ ):
+ assert CudaGraphModule.moe_preprocess not in self.cuda_graph_modules, (
+ 'moe_preprocess cuda graph is not supported when there are '
+ 'DtoH copies and synchronizations in the preprocess step.'
+ )
 if self.recompute_granularity:
 if self.recompute_granularity != "selective":
- assert self.cuda_graph_scope == [
- CudaGraphScope.full_iteration
- ], "full recompute is only supported with full iteration CUDA graph."
+ assert (
+ self.cuda_graph_impl == "full_iteration"
+ ), "full recompute is only supported with full iteration CUDA graph."
 else:
 # The recompute module should be inside or outside of the graph scope.
 # Recompute module covering graph scope is not allowed.
@@ -2071,36 +2176,43 @@ def __post_init__(self):
 and "moe" in self.recompute_modules
 ):
 assert (
- CudaGraphScope.moe_router not in self.cuda_graph_scope
+ CudaGraphModule.moe_router not in self.cuda_graph_modules
 ), "moe recompute is not supported with moe_router CUDA graph with: "
 "--cuda-graph-impl transformer_engine."
 # Graphed recompute module doesn't accept random numbers.
- if (
- not self.cuda_graph_scope
- or CudaGraphScope.full_iteration in self.cuda_graph_scope
- ):
+ # full_cudagraph means either full_iteration impl or an empty per-layer scope
+ # (which captures the whole layer).
+ if self.cuda_graph_impl == "full_iteration" or not self.cuda_graph_modules:
 full_cudagraph = True
 else:
 full_cudagraph = False
 if self.attention_dropout != 0.0:
 assert (
- not full_cudagraph and CudaGraphScope.attn not in self.cuda_graph_scope
+ not full_cudagraph
+ and CudaGraphModule.attn not in self.cuda_graph_modules
 ) or "core_attn" not in self.recompute_modules, (
 "attention dropout is not supported with graphed attention "
 "recomputation."
 )
 if self.hidden_dropout != 0.0:
 assert (
- (not full_cudagraph and CudaGraphScope.mlp not in self.cuda_graph_scope)
+ (
+ not full_cudagraph
+ and CudaGraphModule.mlp not in self.cuda_graph_modules
+ )
 or "mlp" not in self.recompute_modules
 ) and (
- (not full_cudagraph and CudaGraphScope.moe not in self.cuda_graph_scope)
+ (
+ not full_cudagraph
+ and CudaGraphModule.moe not in self.cuda_graph_modules
+ )
 or "moe" not in self.recompute_modules
 ), "hidden dropout is not supported with graphed MLP/MoE recomputation."
if self.moe_input_jitter_eps is not None: assert ( - not full_cudagraph and CudaGraphScope.moe not in self.cuda_graph_scope + not full_cudagraph + and CudaGraphModule.moe not in self.cuda_graph_modules ) or "moe" not in self.recompute_modules, ( "moe_input_jitter_eps is not supported with graphed moe recomputation." ) @@ -2182,8 +2294,8 @@ def __post_init__(self): if self.cuda_graph_impl != "none": assert ( self.cuda_graph_impl == "transformer_engine" - and CudaGraphScope.moe not in self.cuda_graph_scope - and CudaGraphScope.mlp not in self.cuda_graph_scope + and CudaGraphModule.moe not in self.cuda_graph_modules + and CudaGraphModule.mlp not in self.cuda_graph_modules ), ( 'CUDA graph scope on moe and mlp is not ' 'supported with overlap_moe_expert_parallel_comm' diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 2191197a7ac..265f8fce256 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -18,7 +18,7 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import is_graph_capturing -from megatron.core.transformer.enums import CudaGraphScope, LayerType +from megatron.core.transformer.enums import CudaGraphModule, InferenceCudaGraphScope, LayerType from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.module import GraphableMegatronModule @@ -433,7 +433,7 @@ def __init__( def can_recompute_pre_mlp_layernorm_for_cudagraph(): if ( not self.is_moe_layer - or CudaGraphScope.moe_router not in self.config.cuda_graph_scope + or CudaGraphModule.moe_router not in self.config.cuda_graph_modules or self.config.cuda_graph_impl == "local" ): # Not a MoE layer, or not capturing the router part. @@ -453,7 +453,7 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): "recompute.", ) return False - if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and ( + if CudaGraphModule.moe_preprocess in self.config.cuda_graph_modules and ( self.config.moe_token_dispatcher_type == "alltoall" or self.config.moe_latent_size ): @@ -511,18 +511,25 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): def create_mcore_cudagraph_manager(self, config): """Register the transformer layer for cudagraphs.""" + assert self.config.cuda_graph_impl == "local" + from megatron.core.transformer.cuda_graphs import CudaGraphManager - # If full scope, just cudagraph the entire layer - if not self.config.cuda_graph_scope: + # If full scope (no specific sub-scope), cudagraph the entire layer. + # Skip only when inference uses TransformerBlock-level graphs; otherwise the layer keeps + # owning the empty-scope manager. 
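+ # e.g. cuda_graph_modules=[] with inference_cuda_graph_scope=layer keeps the
+ # manager on this layer, while inference_cuda_graph_scope=block leaves graph
+ # ownership to the enclosing TransformerBlock and creates no manager here.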
+ if (
+ not self.config.cuda_graph_modules
+ and self.config.inference_cuda_graph_scope != InferenceCudaGraphScope.block
+ ):
 self.cudagraph_manager = CudaGraphManager(config)
 elif (
- CudaGraphScope.attn in self.config.cuda_graph_scope
+ CudaGraphModule.attn in self.config.cuda_graph_modules
 and self.submodules_config.self_attention != IdentityOp
 ):
 self.cudagraph_manager = CudaGraphManager(config)
 elif (
- CudaGraphScope.mlp in self.config.cuda_graph_scope
+ CudaGraphModule.mlp in self.config.cuda_graph_modules
 and self.submodules_config.mlp != IdentityOp
 ):
 # Cudagraphing MoE layers are supposed to be handled by MoeTransformerLayer
@@ -859,7 +866,7 @@ def _forward_mlp(
 and self.config.cuda_graph_impl == "transformer_engine"
 and self.training
 and is_graph_capturing()
- and CudaGraphScope.moe_router in self.config.cuda_graph_scope
+ and CudaGraphModule.moe_router in self.config.cuda_graph_modules
 ):
 if self.recompute_pre_mlp_layernorm:
 # Register the recompute hooks to all the cudagraph output tensors, because some
@@ -1027,7 +1034,8 @@ def get_layer_static_inputs(self, seq_length, micro_batch_size):
 static_inputs = super().get_layer_static_inputs(seq_length, micro_batch_size)
 if not isinstance(self.self_attention, IdentityOp) and (
- not self.config.cuda_graph_scope or CudaGraphScope.attn in self.config.cuda_graph_scope
+ not self.config.cuda_graph_modules
+ or CudaGraphModule.attn in self.config.cuda_graph_modules
 ):
 slen_per_cp = seq_length // self.config.context_parallel_size
 static_inputs["attention_mask"] = (
@@ -1042,22 +1050,22 @@ def _get_submodules_under_cudagraphs(self):
 """
 Get the submodules that are covered by cudagraphs.
 """
- if not self.config.cuda_graph_scope:
+ if not self.config.cuda_graph_modules:
 return super()._get_submodules_under_cudagraphs()
 submodules = []
- if CudaGraphScope.attn in self.config.cuda_graph_scope:
+ if CudaGraphModule.attn in self.config.cuda_graph_modules:
 submodules += [
 self.input_layernorm,
 self.self_attention,
 self.pre_cross_attn_layernorm,
 self.cross_attention,
 ]
- if (not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope) or (
- self.is_moe_layer and CudaGraphScope.moe in self.config.cuda_graph_scope
+ if (not self.is_moe_layer and CudaGraphModule.mlp in self.config.cuda_graph_modules) or (
+ self.is_moe_layer and CudaGraphModule.moe in self.config.cuda_graph_modules
 ):
 submodules += [self.pre_mlp_layernorm, self.mlp]
- elif self.is_moe_layer and CudaGraphScope.moe_router in self.config.cuda_graph_scope:
+ elif self.is_moe_layer and CudaGraphModule.moe_router in self.config.cuda_graph_modules:
 submodules += [self.pre_mlp_layernorm, self.mlp.router]
 if (
 self.config.moe_shared_expert_intermediate_size is not None
@@ -1070,12 +1078,15 @@ def _te_cuda_graph_capture(self, *args, **kwargs):
 """
 CUDA Graph capture for this layer using TE interface. There are some differences from
 the normal pass:
- 1. In some conditions CUDA graph cannot cover the entire layer. The `cuda_graph_scope`
+ 1. In some conditions CUDA graph cannot cover the entire layer. The `cuda_graph_modules`
 attribute can be set to control the scope of the CUDA graph.
 2. If context is None, it cannot be returned as output.
""" context = None - if not self.config.cuda_graph_scope or CudaGraphScope.attn in self.config.cuda_graph_scope: + if ( + not self.config.cuda_graph_modules + or CudaGraphModule.attn in self.config.cuda_graph_modules + ): hidden_states, context = self._forward_attention(*args, **kwargs) else: if len(args) > 0: @@ -1084,13 +1095,13 @@ def _te_cuda_graph_capture(self, *args, **kwargs): hidden_states = kwargs.pop("hidden_states") if ( - not self.config.cuda_graph_scope - or (not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope) + not self.config.cuda_graph_modules + or (not self.is_moe_layer and CudaGraphModule.mlp in self.config.cuda_graph_modules) or ( self.is_moe_layer and ( - CudaGraphScope.moe in self.config.cuda_graph_scope - or CudaGraphScope.moe_router in self.config.cuda_graph_scope + CudaGraphModule.moe in self.config.cuda_graph_modules + or CudaGraphModule.moe_router in self.config.cuda_graph_modules ) ) ): @@ -1111,7 +1122,10 @@ def _te_cuda_graph_replay(self, *args, **kwargs): Hence, `inference_context` and `packed_seq_params` are excluded from input list. """ context = None - if self.config.cuda_graph_scope and CudaGraphScope.attn not in self.config.cuda_graph_scope: + if ( + self.config.cuda_graph_modules + and CudaGraphModule.attn not in self.config.cuda_graph_modules + ): hidden_states, context = self._forward_attention(*args, **kwargs) args = (hidden_states,) kwargs = {} @@ -1130,9 +1144,9 @@ def _te_cuda_graph_replay(self, *args, **kwargs): context = cuda_graph_output.pop() if ( - not self.config.cuda_graph_scope - or (not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope) - or (self.is_moe_layer and CudaGraphScope.moe in self.config.cuda_graph_scope) + not self.config.cuda_graph_modules + or (not self.is_moe_layer and CudaGraphModule.mlp in self.config.cuda_graph_modules) + or (self.is_moe_layer and CudaGraphModule.moe in self.config.cuda_graph_modules) ): # CUDA Graph captures the whole MLP/MoE part. CUDA Graph output is the layer output. assert len(cuda_graph_output) == 1, "CUDA Graph output should be the layer output." @@ -1141,7 +1155,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): not self.config.overlap_moe_expert_parallel_comm ), "EP overlap must be \ disabled when CUDA graph captures the whole MLP/MoE part." - elif self.is_moe_layer and CudaGraphScope.moe_router in self.config.cuda_graph_scope: + elif self.is_moe_layer and CudaGraphModule.moe_router in self.config.cuda_graph_modules: # CUDA Graph partially captures the MoE. # The rest of the layer should go to the normal pass. shared_expert_output, routing_map = None, None @@ -1154,7 +1168,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): # The shared expert output is the last second element in the CUDA graph output. shared_expert_output = cuda_graph_output.pop() - if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope: + if CudaGraphModule.moe_preprocess in self.config.cuda_graph_modules: # CUDA graph output is [hidden_states, probs] + attributes outputs. 
 (hidden_states, probs), attr_outputs = cuda_graph_output[:2], cuda_graph_output[2:]
 valid_cudagraph_attrs = self.mlp.token_dispatcher.valid_cudagraph_attrs
@@ -1302,7 +1316,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
 (kwargs.get('inference_context') is not None)
 or (kwargs.get('inference_params') is not None)
 )
- and not self.config.cuda_graph_scope # empty-list = per-layer CUDA graphs
+ and not self.config.cuda_graph_modules # empty-list = per-layer CUDA graphs
 ):
 if kwargs['inference_context'].is_static_batching():
 using_cuda_graph = kwargs['inference_context'].is_decode_only()
@@ -1404,17 +1418,22 @@ def create_mcore_cudagraph_manager(self, config):
 Unlike the standard layer, which typically uses a single manager, this method
 can configure multiple graph managers if partial CUDA graphs are enabled via
- `cuda_graph_scope`. This allows capturing the static parts of the MoE pass
+ `cuda_graph_modules`. This allows capturing the static parts of the MoE pass
 while leaving the expert computation to execute eagerly.
 """
+ assert self.config.cuda_graph_impl == "local"
+
 from megatron.core.transformer.cuda_graphs import CudaGraphManager
- if not self.config.cuda_graph_scope or CudaGraphScope.moe in self.config.cuda_graph_scope:
+ if (
+ not self.config.cuda_graph_modules
+ and self.config.inference_cuda_graph_scope != InferenceCudaGraphScope.block
+ ) or CudaGraphModule.moe in self.config.cuda_graph_modules:
 self.cudagraph_manager = CudaGraphManager(config)
 elif (
- CudaGraphScope.moe_router in self.config.cuda_graph_scope
- or CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope
+ CudaGraphModule.moe_router in self.config.cuda_graph_modules
+ or CudaGraphModule.moe_preprocess in self.config.cuda_graph_modules
 ):
 self.transition_cudagraph_scope('partial')
@@ -1519,7 +1538,7 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None)
 if inference_context is not None:
 assert not self.use_partial_cudagraphs, (
 "Partial cudagraphs for MoEs were detected during inference! "
- "Please do not use --cuda-graph-scope moe_router moe_preprocess "
+ "Please do not use --cuda-graph-modules moe_router moe_preprocess "
 "alongside inference."
) diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 0e14251c5aa..725c9b95083 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -25,6 +25,7 @@ TextGenerationController, ) from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer +from megatron.core.transformer.enums import InferenceCudaGraphScope from megatron.core.transformer.module import MegatronModule from megatron.core.utils import get_attr_wrapped_model, log_single_rank, unwrap_model from megatron.training import get_args @@ -347,7 +348,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): mamba_memory_ratio=args.inference_dynamic_batching_mamba_memory_ratio, num_cuda_graphs=( args.inference_dynamic_batching_num_cuda_graphs - if args.cuda_graph_impl == "local" + if args.inference_cuda_graph_scope != InferenceCudaGraphScope.none else None ), max_requests=args.inference_dynamic_batching_max_requests, diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 4d018217cec..c0fafa4e386 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -35,7 +35,7 @@ from megatron.core.tokenizers import MegatronTokenizer from megatron.core.tokenizers.text.libraries.huggingface_tokenizer import HuggingFaceTokenizer from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule from megatron.core.transformer.utils import ( toggle_cuda_graphs, transition_moe_cudagraphs, @@ -1511,7 +1511,7 @@ def prepare_data_for_update( # Before we can update the model, we need to get the logprobs for the \pi_{old} model. forward_backward_func = get_forward_backward_func() - if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: + if args.cuda_graph_impl == "full_iteration": forward_backward_func = FullCudaGraphWrapper( forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps ) @@ -1973,9 +1973,11 @@ def megatron_rl_inference_mode( logger.debug(f"[{dist.get_rank()}] Entering inference mode") - # Set cudagraph scope for inference. - model[0].config.cuda_graph_scope = args.cuda_graph_scope + # Use local CUDA graphs during rollout inference. An empty module list preserves + # full-layer capture when the configured inference scope is layer. + model[0].config.cuda_graph_modules = [] model[0].config.cuda_graph_impl = "local" + model[0].config.inference_cuda_graph_scope = args.inference_cuda_graph_scope # If we get a lower precision wrapper, we go one object deeper. lang_module = model[0].module.module if hasattr(model[0].module, "module") else model[0].module @@ -2033,17 +2035,17 @@ def megatron_rl_inference_mode( # Restore cudagraph scope for training. # MoE partial capture requires specific scopes that aren't user-facing. 
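+ # e.g. on a MoE run (num_experts set) the training modules are forced back to
+ # [mamba, attn, moe_router, moe_preprocess] below, regardless of what was passed
+ # on the command line.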
+ model[0].config.cuda_graph_impl = args.cuda_graph_impl + model[0].config.inference_cuda_graph_scope = args.inference_cuda_graph_scope if args.num_experts is not None: - model[0].config.cuda_graph_scope = [ - CudaGraphScope.mamba, - CudaGraphScope.attn, - CudaGraphScope.moe_router, - CudaGraphScope.moe_preprocess, + model[0].config.cuda_graph_modules = [ + CudaGraphModule.mamba, + CudaGraphModule.attn, + CudaGraphModule.moe_router, + CudaGraphModule.moe_preprocess, ] else: - model[0].config.cuda_graph_scope = [ - s for s in args.cuda_graph_scope if s != CudaGraphScope.full_iteration_inference - ] + model[0].config.cuda_graph_modules = copy.copy(args.cuda_graph_modules) # Switch MoE layers to partial CUDA graph capture for training if args.rl_training_cuda_graphs and args.num_experts is not None: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index be3894999b4..2e42863e919 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -18,7 +18,14 @@ from megatron.core.rerun_state_machine import RerunStateMachine from megatron.core.transformer import MLATransformerConfig, TransformerConfig from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout -from megatron.core.transformer.enums import AttnBackend, CudaGraphScope +from megatron.core.transformer.cuda_graph_config import ( + ALLOWED_INFERENCE_SCOPES, + get_deprecated_cuda_graph_modules_migration, + normalize_cuda_graph_modules, + normalize_inference_cuda_graph_scope, + validate_deprecated_cuda_graph_modules_migration_inputs, +) +from megatron.core.transformer.enums import AttnBackend, CudaGraphModule, InferenceCudaGraphScope from megatron.core.transformer.heterogeneous.heterogeneous_config import ( HeterogeneousTransformerConfig, MLPConfig, @@ -239,6 +246,55 @@ def _eval_pattern(pattern): return eval(pattern) + +def _parse_cuda_graph_modules_arg(scope): + """Parse CUDA graph module CLI values while preserving deprecated spellings for migration.""" + if scope in {"full", "full_iteration", "full_iteration_inference"}: + return scope + return CudaGraphModule[scope] + + +def _normalize_cuda_graph_modules_args(args): + """Normalize cuda_graph_modules to enums and apply deprecated scope migrations.""" + normalized_scopes, deprecated_scopes, used_full_scope = normalize_cuda_graph_modules( + args.cuda_graph_modules + ) + validate_deprecated_cuda_graph_modules_migration_inputs( + deprecated_scopes, + args.cuda_graph_impl, + args.inference_cuda_graph_scope, + ) + if used_full_scope: + warn_rank_0('full scope is deprecated. Use empty cuda_graph_modules to capture the whole layer.') + + for scope, attr, value in deprecated_scopes: + migration = get_deprecated_cuda_graph_modules_migration( + scope, attr, value, args.cuda_graph_impl + ) + if migration is None: + warn_rank_0( + f"--cuda-graph-modules '{scope}' is deprecated and has no effect when " + "--cuda-graph-impl=none. Use --cuda-graph-impl=local with " + "--inference-cuda-graph-scope=block to enable inference CUDA graphs." + ) + continue + migration_attr, migration_value = migration + warn_rank_0( + f"--cuda-graph-modules '{scope}' is deprecated. " + f"Setting --{migration_attr.replace('_', '-')}={migration_value} instead." 
+ ) + setattr(args, migration_attr, migration_value) + + args.cuda_graph_modules = normalized_scopes + + +def _normalize_inference_cuda_graph_scope_arg(args): + """Normalize inference_cuda_graph_scope and apply the impl-derived default.""" + args.inference_cuda_graph_scope = normalize_inference_cuda_graph_scope( + args.inference_cuda_graph_scope, args.cuda_graph_impl + ) + + def no_rope_freq_type(x): """ Controls which layers to skip performing Rotary Position Embedding. - An integer N: Represents a 1:N ratio, meaning RoPE is skipped every N-1 layers. @@ -556,6 +612,29 @@ def validate_args(args, defaults={}): args.cuda_graph_impl = "transformer_engine" del args.external_cuda_graph + if getattr(args, 'cuda_graph_scope_deprecated', None) is not None: + assert not args.cuda_graph_modules, ( + "--cuda-graph-scope and --cuda-graph-modules cannot be used together." + ) + warn_rank_0( + '--cuda-graph-scope is deprecated, use --cuda-graph-modules instead.' + ) + args.cuda_graph_modules = args.cuda_graph_scope_deprecated + del args.cuda_graph_scope_deprecated + + # Normalize cuda_graph_modules and inference_cuda_graph_scope early so that + # all subsequent validation sees fully-typed enum values. + _normalize_cuda_graph_modules_args(args) + _normalize_inference_cuda_graph_scope_arg(args) + assert ( + args.inference_cuda_graph_scope + in ALLOWED_INFERENCE_SCOPES[args.cuda_graph_impl] + ), ( + "Invalid inference CUDA graph scope " + f"{args.inference_cuda_graph_scope.name!r} for " + f"--cuda-graph-impl={args.cuda_graph_impl!r}." + ) + # Set input defaults. for key in defaults: # For default to be valid, it should not be provided in the @@ -1082,17 +1161,17 @@ def validate_args(args, defaults={}): elif not args.accumulate_allreduce_grads_in_fp32 and args.main_grads_dtype == torch.float32: args.accumulate_allreduce_grads_in_fp32 = True print_rank_0('accumulate and all-reduce gradients in fp32 for bfloat16 data type.') - if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: + if args.cuda_graph_impl == "full_iteration": assert not args.check_for_nan_in_loss_and_grad, \ - "--no-check-for-nan-in-loss-and-grad should be set with --cuda-graph-scope=full_iteration for training. Note: If you are trying to use full_iteration CUDA graphs for inference, please use --cuda-graph-scope full_iteration_inference instead" - - if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration_inference in args.cuda_graph_scope: + "--no-check-for-nan-in-loss-and-grad should be set with --cuda-graph-impl=full_iteration for training." 
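+ # e.g. a minimal full-iteration training launch (illustrative):
+ #     --cuda-graph-impl full_iteration --no-check-for-nan-in-loss-and-grad
+ # with --cuda-graph-modules left unset.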
+ + if args.inference_cuda_graph_scope == InferenceCudaGraphScope.block: if args.fp8 is not None: assert args.transformer_impl == "inference_optimized", \ - "fp8 with full_iteration_inference CUDA graphs is only supported with " \ + "fp8 with --inference-cuda-graph-scope=block is only supported with " \ "--transformer-impl=inference_optimized" assert args.fp8_recipe == "mxfp8", \ - "Only --fp8-recipe=mxfp8 is supported with full_iteration_inference CUDA graphs" + "Only --fp8-recipe=mxfp8 is supported with --inference-cuda-graph-scope=block" if args.cuda_graph_impl == 'local': assert args.inference_dynamic_batching_num_cuda_graphs > 0 or args.inference_dynamic_batching_num_cuda_graphs == -1, \ @@ -1586,8 +1665,8 @@ def validate_args(args, defaults={}): assert is_te_min_version("2.8.0"), ( "overlap_grad_reduce is only supported with TE >= 2.8.0 when enabling delay_wgrad_compute" ) - wgrad_in_graph_scope = CudaGraphScope.attn in args.cuda_graph_scope or ( - CudaGraphScope.moe_router in args.cuda_graph_scope + wgrad_in_graph_scope = CudaGraphModule.attn in args.cuda_graph_modules or ( + CudaGraphModule.moe_router in args.cuda_graph_modules and args.moe_shared_expert_intermediate_size is not None and not args.moe_shared_expert_overlap ) @@ -1600,7 +1679,7 @@ def validate_args(args, defaults={}): 'to be enabled. This is because the default gradient accumulation does not ' 'use static memory addresses, which breaks CUDA graph requirements.' ) - if CudaGraphScope.attn in args.cuda_graph_scope: + if CudaGraphModule.attn in args.cuda_graph_modules: assert ( not args.add_bias_linear and not args.add_qkv_bias ), "CUDA graph with delay_wgrad_compute doesn't support attn bias for now." @@ -1645,15 +1724,9 @@ def validate_args(args, defaults={}): "Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using " "CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True." ) - if args.cuda_graph_scope == "full" or ( - isinstance(args.cuda_graph_scope, list) and "full" in args.cuda_graph_scope - ): - if isinstance(args.cuda_graph_scope, list): - assert args.cuda_graph_scope == ["full"], "full scope cannot be used with other scopes." - args.cuda_graph_scope = [] - warn_rank_0( - 'full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer.' - ) + assert not ( + args.cuda_graph_impl == "full_iteration" and args.cuda_graph_modules + ), '--cuda-graph-modules must be empty when --cuda-graph-impl=full_iteration.' if args.multi_latent_attention: assert not args.group_query_attention, "Group query attention is mutually exclusive with multi latent attention." @@ -1821,9 +1894,13 @@ def _add_inference_args(parser): choices=["megatron", "huggingface"], help='Select either Megatron or Huggingface as the ' 'Bert embedder.') - group.add_argument('--cuda-graph-scope', nargs='+', type=lambda scope: CudaGraphScope[scope] if scope != "full" else scope, default=[], - help='Determines the CUDA graphs capturing scope. ' - 'choices: "attn", "mlp", "moe", "moe_router", "moe_preprocess", "mamba", "full_iteration". ' + group.add_argument('--cuda-graph-scope', nargs='+', type=_parse_cuda_graph_modules_arg, + default=None, dest='cuda_graph_scope_deprecated', + help=argparse.SUPPRESS) # hidden; use --cuda-graph-modules instead + group.add_argument('--cuda-graph-modules', nargs='+', type=_parse_cuda_graph_modules_arg, default=[], + help='Selects training capture coverage within per-layer CUDA graphs ' + '(local and transformer_engine implementations). 
'
+ 'Valid values are "attn", "mlp", "moe", "moe_router", "moe_preprocess", and "mamba": '
 '"attn": captures operations in TransformerLayer._forward_attention(). '
 '"mlp": captures operations in TransformerLayer._forward_mlp() for a dense layer. '
 '"moe": captures operations in TransformerLayer._forward_mlp() for a MoE layer. '
@@ -1831,11 +1908,12 @@ def _add_inference_args(parser):
 'including the shared experts if they are not overlapped with EP comm. '
 '"moe_preprocess": captures operations in MoELayer.preprocess(). Must be used together with "moe_router". '
 '"mamba": captures the mamba layer. '
- '"full_iteration": captures a whole training iteration. '
- '"full_iteration_inference": captures a whole inference iteration. '
- 'full_iteration and full_iteration_inference scopes are only supported with --cuda-graph-impl=local, other scopes are only supported with --cuda-graph-impl=transformer_engine. '
- 'If not specified, the default scope is to capture the whole Transformer layer. '
- 'For backward compatibility, we still allow passing "full" to specify capturing the whole layer, and convert it to an empty list.')
+ 'An empty list means capturing the whole Transformer layer. '
+ 'This field has no effect when --cuda-graph-impl=full_iteration and must be empty. '
+ '"full" is deprecated but kept for backward compatibility; '
+ 'it is transformed to an empty list in validate_args. The deprecated values '
+ '"full_iteration" and "full_iteration_inference" are also accepted and migrated '
+ 'to the new API in validate_args.')
 group.add_argument('--use-legacy-static-engine', action='store_true', default=False,
 help='Use legacy static engine. (Current static engine uses dynamic engine under the hood)',
 dest='use_legacy_static_engine')
@@ -2015,7 +2093,8 @@ def _add_network_size_args(parser):
 "moe_router_load_balancing_type",
 "moe_aux_loss_coeff",
 "cp_comm_type",
- "cuda_graph_scope",
+ "cuda_graph_modules",
+ "cuda_graph_scope", # deprecated alias; handled manually by --cuda-graph-scope flag
 # no CLI argument exists for these
 "virtual_pipeline_model_parallel_size",
 "params_dtype",
diff --git a/megatron/training/training.py b/megatron/training/training.py
index c6cab8df952..d064cb2f7b7 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -157,7 +157,6 @@ def set_startup_timestamps(program_start=None, main_entry=None):
 from megatron.core.full_cuda_graph import FullCudaGraphWrapper
 from megatron.core.optimizer.optimizer_cuda_graph import OptimizerCudaGraphWrapper
 from megatron.core.transformer.cuda_graphs import TECudaGraphHelper
-from megatron.core.transformer.enums import CudaGraphScope
 from megatron.core.transformer.module import Float16Module
 from megatron.core.distributed import DistributedDataParallelConfig, TorchFullyShardedDataParallelConfig
 from megatron.core.distributed import DistributedDataParallel as DDP
@@ -2942,7 +2941,7 @@ def train(
 eval_iterations = 0
 # Wrap forward_backward_func for Full iteration CUDA graph
 forward_backward_func = get_forward_backward_func()
- if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope:
+ if args.cuda_graph_impl == "full_iteration":
 forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
 if args.optimizer_cuda_graph:
 optimizer.step = OptimizerCudaGraphWrapper(optimizer.step, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps)
@@ -3456,7 +3455,7 @@ def evaluate(
 eval_micro_batch_size =
args.eval_micro_batch_size eval_num_microbatches = eval_batch_size // (eval_micro_batch_size * args.data_parallel_size) forward_backward_func = get_forward_backward_func() - if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in args.cuda_graph_scope: + if args.cuda_graph_impl == "full_iteration": forward_backward_func = FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=args.cuda_graph_warmup_steps) if has_nvidia_modelopt: diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml index 743c4f50da3..21fd7749ea7 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml @@ -45,7 +45,7 @@ MODEL_ARGS: --top_k: 1 --return-log-probs: true --num-tokens-to-generate: 30 - --enable-cuda-graph: true + --cuda-graph-impl: local --inference-dynamic-batching-buffer-size-gb: 20 --dist-ckpt-strictness: log_unexpected --inference-ckpt-non-strict: true # To handle the extra_state errors diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml index b5dc7cd5bd2..9996433cf1f 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml @@ -45,7 +45,7 @@ MODEL_ARGS: --top_k: 1 --return-log-probs: true --num-tokens-to-generate: 30 - --enable-cuda-graph: true + --cuda-graph-impl: local --decode-only-cuda-graphs: true --inference-dynamic-batching-buffer-size-gb: 20 --dist-ckpt-strictness: log_unexpected diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml index aa4fde5e512..e94dfc8f3e1 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml @@ -44,7 +44,6 @@ MODEL_ARGS: --inference-dynamic-batching-buffer-size-gb: 20 --inference-dynamic-batching-cuda-graph-max-tokens: 2048 --cuda-graph-impl: local - --cuda-graph-scope: full --disable-chunked-prefill: true --dist-ckpt-strictness: log_unexpected --inference-ckpt-non-strict: true # To handle the extra_state errors diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/model_config.yaml index 8a5d0bc3508..ba07a85c024 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_ep8_nanov3_chunked_prefill/model_config.yaml @@ 
-53,8 +53,8 @@ MODEL_ARGS: --inference-dynamic-batching-max-requests: 32 --inference-max-seq-length: 4096 --enable-chunked-prefill: true - --cuda-graph-scope: full_iteration_inference --cuda-graph-impl: local + --inference-cuda-graph-scope: block --inference-dynamic-batching-num-cuda-graphs: -1 --output-path: ${INFERENCE_OUTPUT_PATH} --prompts: 'Artificial intelligence has transformed numerous industries over the past decade. From healthcare to finance, manufacturing to education, AI systems are now capable of performing tasks that once required significant human expertise. Machine learning models can diagnose diseases from medical images, detect fraud in financial transactions, optimize supply chains, and personalize educational content for individual students. Large language models in particular have demonstrated remarkable capabilities in understanding and generating human language, enabling applications such as code generation, document summarization, question answering, and creative writing assistance. As these systems continue to improve, researchers and practitioners are working to address challenges around reliability, fairness, and interpretability. The next generation of AI systems will likely be even more capable, but also require careful consideration of their societal implications. In summary, the field of artificial intelligence is' diff --git a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_resume_torch_dist_attn_cudagraph/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_resume_torch_dist_attn_cudagraph/model_config.yaml index 02134fe47c3..64cdacd6076 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_resume_torch_dist_attn_cudagraph/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_mcore_te_tp2_pp2_ep4_etp1_resume_torch_dist_attn_cudagraph/model_config.yaml @@ -120,7 +120,7 @@ MODEL_ARGS: --tensorboard-dir: ${TENSORBOARD_PATH} # CUDA Graph args --cuda-graph-impl: transformer_engine - --cuda-graph-scope: attn + --cuda-graph-modules: attn --cuda-graph-warmup-steps: 0 --te-rng-tracker: true # Add mixed precision args diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml index cfd0c2b7132..efdac2478fd 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph/model_config.yaml @@ -65,7 +65,7 @@ MODEL_ARGS: --recompute-granularity: selective --recompute-modules: "[moe_act]" --cuda-graph-impl: transformer_engine - --cuda-graph-scope: "[attn mlp moe_router moe_preprocess]" + --cuda-graph-modules: "[attn mlp moe_router moe_preprocess]" --log-memory-to-tensorboard: true --log-params-norm: true --log-num-zeros-in-grad: true diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph_1node/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph_1node/model_config.yaml index a33ffbe1018..26af6497637 100644 --- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph_1node/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp4_ep2_etp2_pp2_scoped_cudagraph_1node/model_config.yaml @@ -64,7 
+64,7 @@ MODEL_ARGS: --recompute-granularity: selective --recompute-modules: "[moe_act]" --cuda-graph-impl: transformer_engine - --cuda-graph-scope: "[attn mlp moe_router moe_preprocess]" + --cuda-graph-modules: "[attn mlp moe_router moe_preprocess]" --log-memory-to-tensorboard: true --log-params-norm: true --log-num-zeros-in-grad: true diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml index afc75144dc8..3bd326a56e1 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml @@ -52,7 +52,7 @@ MODEL_ARGS: --seq-length: 4096 --max-position-embeddings: 4096 --micro-batch-size: 1 - --enable-cuda-graph: true + --cuda-graph-impl: local --te-rng-tracker: true --inference-rng-tracker: true --moe-pad-experts-for-cuda-graph-inference: true diff --git a/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml index 1c22a729f6e..049c9090099 100644 --- a/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml @@ -53,7 +53,7 @@ MODEL_ARGS: --max-position-embeddings: 4096 --micro-batch-size: 1 --flash-decode: true - --enable-cuda-graph: true + --cuda-graph-impl: local --te-rng-tracker: true --inference-rng-tracker: true --moe-pad-experts-for-cuda-graph-inference: true diff --git a/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200/model_config.yaml b/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200/model_config.yaml index 6fa15c22b5d..241b741adff 100644 --- a/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200/model_config.yaml +++ b/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200/model_config.yaml @@ -31,7 +31,7 @@ MODEL_ARGS: --cross-entropy-fusion-impl: native --attention-backend: flash --enable-cuda-graph: true - --cuda-graph-scope: "[mamba attn]" + --cuda-graph-modules: "[mamba attn]" --te-rng-tracker: true --manual-gc: true --manual-gc-interval: 10 diff --git a/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200_sm/model_config.yaml b/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200_sm/model_config.yaml index 5cbc331f2c1..af75e0f095c 100644 --- a/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/nemotron/nemotron3_super_release_gb200_sm/model_config.yaml @@ -31,7 +31,7 @@ MODEL_ARGS: --cross-entropy-fusion-impl: native --attention-backend: flash --enable-cuda-graph: true - --cuda-graph-scope: "[mamba attn]" + --cuda-graph-modules: "[mamba attn]" --te-rng-tracker: true --manual-gc: true --manual-gc-interval: 10 diff --git a/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py b/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py index 85586095bd7..82dab51dc4b 100644 --- 
a/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py +++ b/tests/unit_tests/a2a_overlap/test_cuda_graphed_schedule_chunk_1f1b.py @@ -15,7 +15,7 @@ from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator from megatron.core.pipeline_parallel.utils import set_streams from megatron.core.tensor_parallel.random import HAVE_TE, model_parallel_cuda_manual_seed -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule from megatron.core.transformer.module import float16_to_fp32 from megatron.core.utils import is_te_min_version, unwrap_model from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args @@ -118,7 +118,7 @@ def model_provider( ) def create_test_args( - self, cuda_graph_impl, cuda_graph_scope, cuda_graph_warmup_steps, ep_size, **kwargs + self, cuda_graph_impl, cuda_graph_modules, cuda_graph_warmup_steps, ep_size, **kwargs ): destroy_global_vars() destroy_num_microbatches_calculator() @@ -163,7 +163,7 @@ def create_test_args( # CUDA graph settings args.cuda_graph_impl = cuda_graph_impl - args.cuda_graph_scope = cuda_graph_scope + args.cuda_graph_modules = cuda_graph_modules args.cuda_graph_warmup_steps = cuda_graph_warmup_steps args.use_te_rng_tracker = cuda_graph_impl != "none" @@ -247,7 +247,7 @@ def _run_test_helper( self, ep_size, cuda_graph_impl, - cuda_graph_scope, + cuda_graph_modules, cuda_graph_warmup_steps, ep_overlap=False, **kwargs, @@ -255,7 +255,7 @@ def _run_test_helper( """Test fp8_param with gpt_model.""" args = self.create_test_args( cuda_graph_impl, - cuda_graph_scope, + cuda_graph_modules, cuda_graph_warmup_steps, ep_size, overlap_moe_expert_parallel_comm=ep_overlap, @@ -357,16 +357,16 @@ def test_moe_partial_cudagraph_with_ep_overlap(self, moe_dispatcher_type): extra_kwargs["moe_token_dispatcher_type"] = moe_dispatcher_type loss_list_ref = self._run_test_helper(4, "none", None, 3, **extra_kwargs) - for cuda_graph_scope in [ - [CudaGraphScope.attn], - [CudaGraphScope.attn, CudaGraphScope.moe_router], - [CudaGraphScope.attn, CudaGraphScope.moe_router, CudaGraphScope.moe_preprocess], + for cuda_graph_modules in [ + [CudaGraphModule.attn], + [CudaGraphModule.attn, CudaGraphModule.moe_router], + [CudaGraphModule.attn, CudaGraphModule.moe_router, CudaGraphModule.moe_preprocess], ]: cuda_graph_warmup_steps = 3 loss_list = self._run_test_helper( 4, "transformer_engine", - cuda_graph_scope, + cuda_graph_modules, cuda_graph_warmup_steps, ep_overlap=True, **extra_kwargs, @@ -375,5 +375,5 @@ def test_moe_partial_cudagraph_with_ep_overlap(self, moe_dispatcher_type): for i in range(len(loss_list)): assert torch.equal( loss_list[i].mean(), loss_list_ref[i].mean() - ), f"scope={cuda_graph_scope}, i={i},loss_list={loss_list[i]}, loss_list_ref={loss_list_ref[i]}" - print(f"[DEBUG] Pass {cuda_graph_scope}") + ), f"scope={cuda_graph_modules}, i={i},loss_list={loss_list[i]}, loss_list_ref={loss_list_ref[i]}" + print(f"[DEBUG] Pass {cuda_graph_modules}") diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 7bcf21882c1..d0849594df9 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -51,7 +51,7 @@ from megatron.core.ssm.mamba_mixer import _check_mamba_sequence_packing_support from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from 
megatron.core.transformer.cuda_graphs import delete_cuda_graphs -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule, InferenceCudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_fa_min_version, is_te_min_version from tests.unit_tests.test_utilities import Utils, clear_nvte_env_vars @@ -128,9 +128,9 @@ class DynamicEngineTestConfig: skip_prompt_log_probs: bool = False enable_chunked_prefill: bool = False enable_prefix_caching: bool = False - cuda_graph_scope: List[CudaGraphScope] = field( - default_factory=lambda: [CudaGraphScope.full_iteration_inference] - ) + cuda_graph_modules: List[CudaGraphModule] = field(default_factory=list) + inference_cuda_graph_scope: InferenceCudaGraphScope = InferenceCudaGraphScope.block + cuda_graph_impl: Optional[str] = None force_build_cuda_graphs: bool = False transformer_impl: str = "local" inference_moe_token_dispatcher_type: str = "nccl" @@ -182,6 +182,23 @@ class DynamicEngineTestEnv: class DynamicInferenceEngineTestBase: + @staticmethod + def _assert_inference_cuda_graphs_disabled(env) -> None: + model = env.engine.controller.inference_wrapped_model.model + assert env.engine.cuda_graph_impl == "full_iteration" + assert env.engine.inference_cuda_graph_scope == InferenceCudaGraphScope.none + assert env.engine.capture_stats is None + assert not hasattr(model.decoder, 'cudagraph_manager') + for layer in model.decoder.layers: + assert not hasattr(layer, 'cudagraph_manager') + + @staticmethod + def _cuda_graph_batch_dimension_signature(env) -> List[Tuple[int, int, int]]: + return [ + (dim.token_count, dim.prefill_req_count, dim.decode_req_count) + for dim in env.engine.context.cuda_graph_batch_dimensions_list + ] + @classmethod def _build_requests(cls, test_config: DynamicEngineTestConfig) -> List[DynamicInferenceRequest]: @@ -299,6 +316,13 @@ def _build_test_env(cls, test_config): # Requests. requests = cls._build_requests(test_config) + effective_cuda_graph_impl = test_config.cuda_graph_impl + if effective_cuda_graph_impl is None: + effective_cuda_graph_impl = ( + "local" + if test_config.num_cuda_graphs is not None and test_config.force_build_cuda_graphs + else "none" + ) if test_config.model_provider == "gpt": # Transformer config. 
@@ -309,12 +333,7 @@ def _build_test_env(cls, test_config): hidden_size=128 if test_config.fp8 else 32, num_attention_heads=4, use_cpu_initialization=True, - cuda_graph_impl=( - "local" - if test_config.num_cuda_graphs is not None - and test_config.force_build_cuda_graphs - else "none" - ), + cuda_graph_impl=effective_cuda_graph_impl, inference_rng_tracker=True, tensor_model_parallel_size=test_config.tensor_model_parallel_size, pipeline_model_parallel_size=test_config.pipeline_model_parallel_size, @@ -331,7 +350,13 @@ def _build_test_env(cls, test_config): fp8="hybrid" if test_config.fp8 else None, fp8_recipe="tensorwise" if test_config.fp8 else None, inference_sampling_seed=test_config.random_seed, - cuda_graph_scope=test_config.cuda_graph_scope, + cuda_graph_modules=test_config.cuda_graph_modules, + inference_cuda_graph_scope=( + test_config.inference_cuda_graph_scope + if test_config.num_cuda_graphs is not None + and test_config.force_build_cuda_graphs + else InferenceCudaGraphScope.none + ), transformer_impl=test_config.transformer_impl, inference_moe_token_dispatcher_type=( test_config.inference_moe_token_dispatcher_type @@ -383,12 +408,7 @@ def _build_test_env(cls, test_config): mamba_num_heads=16, num_attention_heads=16, use_cpu_initialization=True, - cuda_graph_impl=( - "local" - if test_config.num_cuda_graphs is not None - and test_config.force_build_cuda_graphs - else "none" - ), + cuda_graph_impl=effective_cuda_graph_impl, inference_rng_tracker=True, tensor_model_parallel_size=test_config.tensor_model_parallel_size, pipeline_model_parallel_size=pp_size, @@ -405,7 +425,13 @@ def _build_test_env(cls, test_config): fp8="hybrid" if test_config.fp8 else None, fp8_recipe="tensorwise" if test_config.fp8 else None, inference_sampling_seed=test_config.random_seed, - cuda_graph_scope=test_config.cuda_graph_scope, + cuda_graph_modules=test_config.cuda_graph_modules, + inference_cuda_graph_scope=( + test_config.inference_cuda_graph_scope + if test_config.num_cuda_graphs is not None + and test_config.force_build_cuda_graphs + else InferenceCudaGraphScope.none + ), transformer_impl=test_config.transformer_impl, inference_moe_token_dispatcher_type=( test_config.inference_moe_token_dispatcher_type @@ -597,8 +623,10 @@ def teardown_class(cls): ) @pytest.mark.parametrize("model_provider", ["gpt", "hybrid"]) @pytest.mark.parametrize("num_cuda_graphs", [None, 1, 4, -1]) - @pytest.mark.parametrize("cuda_graph_scope", [[], [CudaGraphScope.full_iteration_inference]]) - def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None: + @pytest.mark.parametrize( + "inference_cuda_graph_scope", [InferenceCudaGraphScope.layer, InferenceCudaGraphScope.block] + ) + def test_simple(self, model_provider, num_cuda_graphs, inference_cuda_graph_scope) -> None: """Simple test that runs without errors, and validates output.""" skip_if_mamba_sequence_packing_not_available(model_provider) num_tokens_to_generate = 16 @@ -608,7 +636,7 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None num_tokens_to_generate=num_tokens_to_generate, model_provider=model_provider, num_cuda_graphs=num_cuda_graphs, - cuda_graph_scope=cuda_graph_scope, + inference_cuda_graph_scope=inference_cuda_graph_scope, force_build_cuda_graphs=True, context_max_requests=128, ) @@ -620,7 +648,7 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None assert env.engine.context.cuda_graph_token_counts is not None assert env.engine.context.cuda_graph_batch_dimensions_list model = 
env.engine.controller.inference_wrapped_model.model - if cuda_graph_scope == [CudaGraphScope.full_iteration_inference]: + if inference_cuda_graph_scope == InferenceCudaGraphScope.block: # hybrid models attach cudagraph_manager to the model; others attach to the decoder if model_provider == "hybrid": assert model.cudagraph_manager.cudagraph_runners @@ -674,6 +702,63 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None f"expected ({expected_generated_tokens})." ) + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + def test_full_iteration_impl_does_not_setup_inference_cuda_graphs(self) -> None: + """impl=full_iteration is training-only; inference graph setup follows inference scope.""" + env = self._build_test_env( + DynamicEngineTestConfig( + model_provider="gpt", + num_tokens_to_generate=4, + num_cuda_graphs=1, + force_build_cuda_graphs=True, + context_max_requests=128, + cuda_graph_impl="full_iteration", + inference_cuda_graph_scope=InferenceCudaGraphScope.none, + ) + ) + + self._assert_inference_cuda_graphs_disabled(env) + + with mock.patch.object( + env.engine.controller, + "_dynamic_step_forward_logits", + wraps=env.engine.controller._dynamic_step_forward_logits, + ) as forward_logits: + with torch.inference_mode(): + env.engine.create_cuda_graphs() + + assert forward_logits.call_count == 0 + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + def test_deprecated_full_iteration_inference_scope_matches_new_flag_runtime_behavior( + self, + ) -> None: + """Deprecated scope='full_iteration_inference' must still build block-level graphs.""" + with pytest.warns( + DeprecationWarning, match="cuda_graph_modules 'full_iteration_inference' is deprecated" + ): + env = self._run_test( + num_tokens_to_generate=4, + model_provider="gpt", + num_cuda_graphs=1, + cuda_graph_modules='full_iteration_inference', + force_build_cuda_graphs=True, + context_max_requests=128, + ) + + model = env.engine.controller.inference_wrapped_model.model + assert model.config.inference_cuda_graph_scope == InferenceCudaGraphScope.block + assert model.config.cuda_graph_modules == [] + assert model.decoder.cudagraph_manager.cudagraph_runners + for layer in model.decoder.layers: + assert not hasattr(layer, 'cudagraph_manager') + @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" ) diff --git a/tests/unit_tests/inference/test_moe_dispatching_and_routing.py b/tests/unit_tests/inference/test_moe_dispatching_and_routing.py index a9caa12f178..94234362e23 100644 --- a/tests/unit_tests/inference/test_moe_dispatching_and_routing.py +++ b/tests/unit_tests/inference/test_moe_dispatching_and_routing.py @@ -17,7 +17,7 @@ from megatron.core.activations import squared_relu from megatron.core.inference.communication.torch_symm_triton import are_tensors_nvls_eligible -from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.enums import AttnBackend, InferenceCudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_te_min_version, is_torch_min_version from megatron.training.initialize import _set_random_seed @@ -64,7 +64,7 @@ use_cpu_initialization=True, attention_backend=AttnBackend.local, cuda_graph_impl="local", - cuda_graph_scope="full_iteration_inference", + 
inference_cuda_graph_scope=InferenceCudaGraphScope.block, moe_pad_experts_for_cuda_graph_inference=False, mamba_state_dim=128, mamba_head_dim=64, diff --git a/tests/unit_tests/models/test_hybrid_moe_model.py b/tests/unit_tests/models/test_hybrid_moe_model.py index 3935964c975..4cecc2f9a87 100644 --- a/tests/unit_tests/models/test_hybrid_moe_model.py +++ b/tests/unit_tests/models/test_hybrid_moe_model.py @@ -69,8 +69,9 @@ "cross_entropy_loss_fusion": True, "cuda_graph_impl": "none", "cuda_graph_retain_backward_graph": False, - "cuda_graph_scope": [], + "cuda_graph_modules": [], "cuda_graph_use_single_mempool": False, + "cuda_graph_scope": None, "cuda_graph_warmup_steps": 3, "deallocate_pipeline_outputs": True, "defer_embedding_wgrad_compute": False, @@ -288,6 +289,12 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, + "inference_cuda_graph_scope": { + "__objclass__": "megatron.core.transformer.enums.InferenceCudaGraphScope", + "_name_": "none", + "_sort_order_": 0, + "_value_": 1, + }, "inference_disable_triton_nvls_kernels": False, "moe_router_force_biased": None, "inference_grouped_gemm_backend": "vllm", diff --git a/tests/unit_tests/rl/test_rl_utils.py b/tests/unit_tests/rl/test_rl_utils.py index 6bf6e994ffb..0a04caa8732 100644 --- a/tests/unit_tests/rl/test_rl_utils.py +++ b/tests/unit_tests/rl/test_rl_utils.py @@ -1,7 +1,9 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import itertools -from unittest.mock import MagicMock +from contextlib import nullcontext +from types import SimpleNamespace +from unittest.mock import MagicMock, call import numpy as np import pytest @@ -28,6 +30,7 @@ create_cudagraphs, delete_cuda_graphs, ) +from megatron.core.transformer.enums import CudaGraphModule, InferenceCudaGraphScope from megatron.core.transformer.module import Float16Module from megatron.rl import rl_utils from megatron.rl.agent.api import TokenRollout @@ -84,6 +87,27 @@ def detokenize(self, tokens): return [str(tok) for tok in tokens] +class DummyLangModule: + def __init__(self, config): + self.config = config + self.rotary_pos_emb = None + self.eval = MagicMock() + self.train = MagicMock() + + def modules(self): + return iter(()) + + +class DummyMoELayer: + def __init__(self, use_partial_cudagraphs): + self.use_partial_cudagraphs = use_partial_cudagraphs + self.transition_calls = [] + + def transition_cudagraph_scope(self, mode): + self.transition_calls.append(mode) + self.use_partial_cudagraphs = mode == "partial" + + @pytest.fixture def initialize_model_parallel(request, monkeypatch): """Fixture to initialize and destroy model parallel. 
@@ -141,6 +165,73 @@ def create_test_args(self, **kwargs): set_global_variables(args, False) return args + def _patch_rl_inference_mode_deps(self, monkeypatch, args): + interface = MagicMock() + interface.resume.return_value = object() + interface.suspend.return_value = object() + loop = SimpleNamespace(run_until_complete=MagicMock()) + + monkeypatch.setattr(rl_utils, "get_args", lambda: args) + monkeypatch.setattr(rl_utils, "get_asyncio_loop", lambda: loop) + monkeypatch.setattr( + rl_utils, "get_nvtx_range", lambda: (lambda *args, **kwargs: nullcontext()) + ) + monkeypatch.setattr(rl_utils, "get_inference_interface", lambda *_args: interface) + monkeypatch.setattr( + rl_utils, + "unwrap_model", + lambda model: model.module if hasattr(model, "module") else model, + ) + monkeypatch.setattr( + rl_utils, "_maybe_prefetch_separate_inference_model_weights", MagicMock() + ) + monkeypatch.setattr(rl_utils, "set_decode_expert_padding", MagicMock()) + monkeypatch.setattr(rl_utils.dist, "get_rank", lambda: 0) + return interface, loop + + def _make_toggle_cuda_graphs_mock(self): + def _toggle(lang_module, set_to): + assert set_to in {"none", "local"}, f"Invalid CUDA graph implementation: {set_to}" + lang_module.config.cuda_graph_impl = set_to + + return MagicMock(side_effect=_toggle) + + def test_megatron_rl_inference_mode_restores_training_cuda_graph_state(self, monkeypatch): + config = SimpleNamespace( + cuda_graph_impl="none", + cuda_graph_modules=[CudaGraphModule.attn], + inference_cuda_graph_scope=InferenceCudaGraphScope.none, + ) + lang_module = DummyLangModule(config) + model = [SimpleNamespace(config=config, module=lang_module)] + args = SimpleNamespace( + rl_training_cuda_graphs=False, + num_experts=None, + curr_iteration=11, + cuda_graph_impl="local", + cuda_graph_modules=[CudaGraphModule.attn], + inference_cuda_graph_scope=InferenceCudaGraphScope.block, + ) + interface, _ = self._patch_rl_inference_mode_deps(monkeypatch, args) + toggle_cuda_graphs = self._make_toggle_cuda_graphs_mock() + monkeypatch.setattr(rl_utils, "toggle_cuda_graphs", toggle_cuda_graphs) + + with rl_utils.megatron_rl_inference_mode(model, MagicMock(), "local", False) as result: + assert result is interface + assert config.cuda_graph_impl == "local" + assert config.cuda_graph_modules == [] + assert config.inference_cuda_graph_scope == InferenceCudaGraphScope.block + + assert toggle_cuda_graphs.call_args_list == [ + call(lang_module, "local"), + call(lang_module, "none"), + ] + assert config.cuda_graph_impl == "local" + assert config.cuda_graph_modules == [CudaGraphModule.attn] + assert config.inference_cuda_graph_scope == InferenceCudaGraphScope.block + lang_module.eval.assert_called_once() + lang_module.train.assert_called_once() + @pytest.mark.parametrize( "initialize_model_parallel", [ diff --git a/tests/unit_tests/test_fp4_param.py b/tests/unit_tests/test_fp4_param.py index f01d6592d23..66f67f83c2c 100644 --- a/tests/unit_tests/test_fp4_param.py +++ b/tests/unit_tests/test_fp4_param.py @@ -139,7 +139,7 @@ def create_test_args( if kwargs.get("enable_cuda_graph", False): args.cuda_graph_impl = "transformer_engine" args.cuda_graph_warmup_steps = 0 - args.cuda_graph_scope = "full" + args.cuda_graph_modules = []  # empty = whole-layer capture; the old "full" scope string has no CudaGraphModule equivalent for key, value in kwargs.items(): if key == "enable_cuda_graph": diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py index fb81fe588bf..5d2434165f5 100644 --- a/tests/unit_tests/transformer/test_cuda_graphs.py +++
b/tests/unit_tests/transformer/test_cuda_graphs.py @@ -35,7 +35,7 @@ TECudaGraphHelper, _CudagraphGlobalRecord, ) -from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.enums import CudaGraphModule, CudaGraphScope, InferenceCudaGraphScope from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.fused_a2a import reset_hybrid_ep_buffer @@ -44,6 +44,7 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import is_te_min_version +from megatron.training import arguments as training_arguments from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args from megatron.training.global_vars import ( destroy_global_vars, @@ -57,6 +58,273 @@ fp8_available, _ = check_fp8_support() +def _base_cuda_graph_config(**kwargs) -> TransformerConfig: + return TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4, **kwargs) + + +def _validated_cuda_graph_cli_args(monkeypatch, cli_args=None, **overrides): + destroy_global_vars() + destroy_num_microbatches_calculator() + + warning_messages = [] + print_messages = [] + + monkeypatch.setattr( + training_arguments, "warn_rank_0", lambda msg, *args, **kwargs: warning_messages.append(msg) + ) + monkeypatch.setattr( + training_arguments, "print_rank_0", lambda msg, *args, **kwargs: print_messages.append(msg) + ) + monkeypatch.setattr(sys, "argv", ["test_cuda_graphs.py", *(cli_args or [])]) + + args = parse_args() + args.num_layers = 2 + args.vocab_size = 256 + args.hidden_size = 64 + args.num_attention_heads = 4 + args.max_position_embeddings = 128 + args.seq_length = 128 + args.micro_batch_size = 1 + + for key, value in overrides.items(): + setattr(args, key, value) + + args = validate_args(args) + return args, warning_messages, print_messages + + +class TestCudaGraphConfigAndArguments: + def test_local_impl_defaults_to_layer_scope(self): + cfg = _base_cuda_graph_config(cuda_graph_impl='local') + assert cfg.inference_cuda_graph_scope == InferenceCudaGraphScope.layer + + def test_full_iteration_impl_requires_empty_scope(self): + with pytest.raises( + AssertionError, + match='cuda_graph_modules must be empty when cuda_graph_impl="full_iteration"', + ): + _base_cuda_graph_config( + cuda_graph_impl='full_iteration', cuda_graph_modules=[CudaGraphModule.attn] + ) + + def test_full_iteration_scope_string_in_config_migrated(self): + with pytest.warns(DeprecationWarning, match="deprecated"): + cfg = _base_cuda_graph_config( + cuda_graph_impl='local', cuda_graph_modules='full_iteration' + ) + assert cfg.cuda_graph_impl == 'full_iteration' + assert cfg.cuda_graph_modules == [] + assert cfg.cuda_graph_scope is None + + def test_full_iteration_inference_scope_string_in_config_migrated(self): + with pytest.warns(DeprecationWarning, match="deprecated"): + cfg = _base_cuda_graph_config( + cuda_graph_impl='local', cuda_graph_modules='full_iteration_inference' + ) + assert cfg.inference_cuda_graph_scope == InferenceCudaGraphScope.block + assert cfg.cuda_graph_modules == [] + assert cfg.cuda_graph_scope is None + + def test_full_iteration_inference_scope_string_noops_without_local_impl(self): + with pytest.warns(DeprecationWarning, match="has no effect"): + cfg = _base_cuda_graph_config(cuda_graph_modules='full_iteration_inference') + assert cfg.cuda_graph_impl == 'none' + assert 
cfg.inference_cuda_graph_scope == InferenceCudaGraphScope.none + assert cfg.cuda_graph_modules == [] + assert cfg.cuda_graph_scope is None + + def test_deprecated_full_iteration_scope_rejects_conflicting_new_scope(self): + with pytest.raises( + AssertionError, + match="cuda_graph_modules='full_iteration' cannot be combined with " + "inference_cuda_graph_scope='block'", + ): + _base_cuda_graph_config( + cuda_graph_impl='local', + cuda_graph_modules='full_iteration', + inference_cuda_graph_scope='block', + ) + + def test_deprecated_full_iteration_inference_scope_rejects_conflicting_new_scope(self): + with pytest.raises( + AssertionError, + match="cuda_graph_modules='full_iteration_inference' cannot be combined with " + "inference_cuda_graph_scope='layer'", + ): + _base_cuda_graph_config( + cuda_graph_impl='local', + cuda_graph_modules='full_iteration_inference', + inference_cuda_graph_scope='layer', + ) + + def test_enable_cuda_graph_flag_migrates_to_local_impl(self, monkeypatch): + args, _, print_messages = _validated_cuda_graph_cli_args( + monkeypatch, ['--enable-cuda-graph'] + ) + assert args.cuda_graph_impl == 'local' + assert any("--enable-cuda-graph is deprecated" in msg for msg in print_messages) + + def test_full_iteration_inference_scope_cli_migrates_to_block_scope(self, monkeypatch): + args, warning_messages, _ = _validated_cuda_graph_cli_args( + monkeypatch, + ['--cuda-graph-impl', 'local', '--cuda-graph-modules', 'full_iteration_inference'], + ) + assert args.cuda_graph_impl == 'local' + assert args.inference_cuda_graph_scope == InferenceCudaGraphScope.block + assert args.cuda_graph_modules == [] + assert any( + "--cuda-graph-modules 'full_iteration_inference' is deprecated" in msg + for msg in warning_messages + ) + + def test_full_iteration_inference_scope_cli_noops_without_local_impl(self, monkeypatch): + args, warning_messages, _ = _validated_cuda_graph_cli_args( + monkeypatch, ['--cuda-graph-scope', 'full_iteration_inference'] + ) + assert args.cuda_graph_impl == 'none' + assert args.inference_cuda_graph_scope == InferenceCudaGraphScope.none + assert args.cuda_graph_modules == [] + assert any("has no effect when --cuda-graph-impl=none" in msg for msg in warning_messages) + + def test_full_iteration_inference_scope_cli_rejects_conflicting_new_scope(self, monkeypatch): + with pytest.raises( + AssertionError, + match="cuda_graph_modules='full_iteration_inference' cannot be combined with " + "inference_cuda_graph_scope='layer'", + ): + _validated_cuda_graph_cli_args( + monkeypatch, + [ + '--cuda-graph-impl', + 'local', + '--cuda-graph-modules', + 'full_iteration_inference', + '--inference-cuda-graph-scope', + 'layer', + ], + ) + + def test_new_scope_cli_accepts_block(self, monkeypatch): + args, _, _ = _validated_cuda_graph_cli_args( + monkeypatch, ['--cuda-graph-impl', 'local', '--inference-cuda-graph-scope', 'block'] + ) + assert args.cuda_graph_impl == 'local' + assert args.inference_cuda_graph_scope == InferenceCudaGraphScope.block + + def test_new_scope_cli_accepts_layer(self, monkeypatch): + args, _, _ = _validated_cuda_graph_cli_args( + monkeypatch, ['--cuda-graph-impl', 'local', '--inference-cuda-graph-scope', 'layer'] + ) + assert args.cuda_graph_impl == 'local' + assert args.inference_cuda_graph_scope == InferenceCudaGraphScope.layer + + def test_removed_module_scoped_scope_name_is_not_accepted(self, monkeypatch): + destroy_global_vars() + destroy_num_microbatches_calculator() + monkeypatch.setattr( + sys, + "argv", + [ + 'test_cuda_graphs.py', + '--cuda-graph-impl', + 
'local', + '--inference-cuda-graph-scope', + 'module_scoped', + ], + ) + with pytest.raises(SystemExit): + parse_args() + + def test_removed_old_inference_bool_flag_is_not_accepted(self, monkeypatch): + destroy_global_vars() + destroy_num_microbatches_calculator() + monkeypatch.setattr( + sys, "argv", ['test_cuda_graphs.py', '--inference-use-full-iteration-cuda-graph'] + ) + with pytest.raises(SystemExit): + parse_args() + + # --- Backward compat: cuda_graph_scope → cuda_graph_modules rename --- + + def test_deprecated_cuda_graph_scope_kwarg_migrates_to_modules(self): + with pytest.warns(DeprecationWarning, match="cuda_graph_scope is deprecated"): + cfg = _base_cuda_graph_config(cuda_graph_scope=['attn']) + assert cfg.cuda_graph_modules == [CudaGraphModule.attn] + assert cfg.cuda_graph_scope is None + + def test_new_cuda_graph_modules_does_not_populate_deprecated_scope(self): + cfg = _base_cuda_graph_config(cuda_graph_modules=['attn', 'mlp']) + assert cfg.cuda_graph_modules == [CudaGraphModule.attn, CudaGraphModule.mlp] + assert cfg.cuda_graph_scope is None + + def test_new_full_iteration_impl_does_not_populate_deprecated_scope(self): + cfg = _base_cuda_graph_config(cuda_graph_impl='full_iteration', cuda_graph_modules=[]) + assert cfg.cuda_graph_scope is None + + def test_deprecated_cuda_graph_scope_cli_migrates_to_modules(self, monkeypatch): + args, warning_messages, _ = _validated_cuda_graph_cli_args( + monkeypatch, ['--cuda-graph-impl', 'local', '--cuda-graph-scope', 'attn'] + ) + assert args.cuda_graph_modules == [CudaGraphModule.attn] + assert any('--cuda-graph-scope is deprecated' in msg for msg in warning_messages) + + def test_cuda_graph_scope_is_standalone_class_for_pickle_compat(self): + from megatron.core.transformer.enums import CudaGraphScope + + # CudaGraphScope is preserved as a standalone class (not an alias) so that + # pre-refactor checkpoints can be deserialized without value-collision errors. 
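+ # Concretely: Python enum members pickle by value (Enum.__reduce_ex__ reconstructs a + # member as Class(value)), so loading a pre-refactor checkpoint calls + # CudaGraphScope(value); keeping the original ordinals keeps that lookup unambiguous.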
+ assert CudaGraphScope is not CudaGraphModule + assert CudaGraphScope.attn.value == 2 # original ordinals preserved + assert CudaGraphScope.mamba.value == 7 + + def test_cuda_graph_scope_and_inference_scope_in_safe_globals(self): + from megatron.core.safe_globals import SAFE_GLOBALS + from megatron.core.transformer.enums import CudaGraphScope + + assert CudaGraphScope in SAFE_GLOBALS + assert InferenceCudaGraphScope in SAFE_GLOBALS + + def test_deprecated_cuda_graph_scope_enum_instance_migrates_to_modules(self): + from megatron.core.transformer.enums import CudaGraphScope + + with pytest.warns(DeprecationWarning, match="cuda_graph_scope is deprecated"): + cfg = _base_cuda_graph_config(cuda_graph_scope=[CudaGraphScope.attn]) + assert cfg.cuda_graph_modules == [CudaGraphModule.attn] + assert cfg.cuda_graph_scope is None + + def test_deprecated_cuda_graph_scope_full_iteration_enum_migrates_to_impl(self): + from megatron.core.transformer.enums import CudaGraphScope + + with pytest.warns(DeprecationWarning): + cfg = _base_cuda_graph_config(cuda_graph_scope=[CudaGraphScope.full_iteration]) + assert cfg.cuda_graph_impl == "full_iteration" + assert cfg.cuda_graph_modules == [] + assert cfg.cuda_graph_scope is None + + def test_deprecated_cuda_graph_scope_full_iteration_inference_enum_migrates_to_scope(self): + from megatron.core.transformer.enums import CudaGraphScope + + with pytest.warns(DeprecationWarning): + cfg = _base_cuda_graph_config( + cuda_graph_impl="local", cuda_graph_scope=[CudaGraphScope.full_iteration_inference] + ) + assert cfg.inference_cuda_graph_scope == InferenceCudaGraphScope.block + assert cfg.cuda_graph_modules == [] + assert cfg.cuda_graph_scope is None + + def test_deprecated_cuda_graph_scope_full_iteration_inference_noops_without_local_impl(self): + from megatron.core.transformer.enums import CudaGraphScope + + with pytest.warns(DeprecationWarning, match="has no effect"): + cfg = _base_cuda_graph_config( + cuda_graph_scope=[CudaGraphScope.full_iteration_inference] + ) + assert cfg.cuda_graph_impl == "none" + assert cfg.inference_cuda_graph_scope == InferenceCudaGraphScope.none + assert cfg.cuda_graph_modules == [] + assert cfg.cuda_graph_scope is None + + class TestParallelTransformerBlockCudagraphs: def setup_method(self, method): # initialize parallel state @@ -877,7 +1145,7 @@ def model_provider( ) def create_test_args( - self, cuda_graph_impl, cuda_graph_scope, cuda_graph_warmup_steps, ep_size, **kwargs + self, cuda_graph_impl, cuda_graph_modules, cuda_graph_warmup_steps, ep_size, **kwargs ): destroy_global_vars() destroy_num_microbatches_calculator() @@ -922,7 +1190,7 @@ def create_test_args( # CUDA graph settings args.cuda_graph_impl = cuda_graph_impl - args.cuda_graph_scope = cuda_graph_scope + args.cuda_graph_modules = cuda_graph_modules args.cuda_graph_warmup_steps = cuda_graph_warmup_steps # fp8 settings @@ -953,11 +1221,11 @@ def get_batch(self, seq_length, micro_batch_size, cp_size): return input_ids, labels, position_ids, attention_mask, loss_mask def _run_test_helper( - self, ep_size, cuda_graph_impl, cuda_graph_scope, cuda_graph_warmup_steps, **kwargs + self, ep_size, cuda_graph_impl, cuda_graph_modules, cuda_graph_warmup_steps, **kwargs ): """Test fp8_param with gpt_model.""" args = self.create_test_args( - cuda_graph_impl, cuda_graph_scope, cuda_graph_warmup_steps, ep_size, **kwargs + cuda_graph_impl, cuda_graph_modules, cuda_graph_warmup_steps, ep_size, **kwargs ) set_args(args) @@ -1062,20 +1330,20 @@ def test_moe_partial_cudagraph(self, ep_size, 
moe_dropless_dispatcher, moe_dispa extra_kwargs["moe_pad_expert_input_to_capacity"] = True loss_list_ref = self._run_test_helper(ep_size, "none", None, 0, **extra_kwargs) - for cuda_graph_scope in [ + for cuda_graph_modules in [ None, - [CudaGraphScope.attn], - [CudaGraphScope.moe], - [CudaGraphScope.mlp, CudaGraphScope.moe_router], + [CudaGraphModule.attn], + [CudaGraphModule.moe], + [CudaGraphModule.mlp, CudaGraphModule.moe_router], [ - CudaGraphScope.attn, - CudaGraphScope.mlp, - CudaGraphScope.moe_router, - CudaGraphScope.moe_preprocess, + CudaGraphModule.attn, + CudaGraphModule.mlp, + CudaGraphModule.moe_router, + CudaGraphModule.moe_preprocess, ], ]: if (moe_dropless_dispatcher or moe_dispatcher_type == "hybridep") and ( - cuda_graph_scope is None or CudaGraphScope.moe in cuda_graph_scope + cuda_graph_modules is None or CudaGraphModule.moe in cuda_graph_modules ): # Dropless MoE or Hybrid EP doesn't work with "moe" scope cudagraph. Skip. continue @@ -1083,7 +1351,7 @@ def test_moe_partial_cudagraph(self, ep_size, moe_dropless_dispatcher, moe_dispa loss_list = self._run_test_helper( ep_size, "transformer_engine", - cuda_graph_scope, + cuda_graph_modules, cuda_graph_warmup_steps, **extra_kwargs, ) diff --git a/tests/unit_tests/transformer/test_transformer_layer.py b/tests/unit_tests/transformer/test_transformer_layer.py index 93570e07678..50f9476d58f 100644 --- a/tests/unit_tests/transformer/test_transformer_layer.py +++ b/tests/unit_tests/transformer/test_transformer_layer.py @@ -1,6 +1,8 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import gc + import pytest import torch @@ -8,14 +10,22 @@ from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_with_transformer_engine_submodules, ) -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.tensor_parallel.random import ( + HAVE_TE, + initialize_rng_tracker, + model_parallel_cuda_manual_seed, +) +from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord +from megatron.core.transformer.enums import InferenceCudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( TransformerLayer, get_transformer_layer_offset, ) +from megatron.core.utils import is_te_min_version from tests.unit_tests.test_utilities import Utils @@ -336,3 +346,73 @@ def get_tensor_shapes_for_tp(transformer_config, tp_size): 'self_attention.linear_qkv.weight': (hs * 3 // tp_size, hs), 'self_attention.linear_qkv.bias': (hs * 3 // tp_size,), } + + +def _make_cuda_graph_gpt_block(**config_kwargs): + cfg = TransformerConfig( + num_layers=2, + hidden_size=64, + num_attention_heads=4, + use_cpu_initialization=True, + **config_kwargs, + ) + from megatron.core.transformer.transformer_block import TransformerBlock + + return TransformerBlock(cfg, get_gpt_layer_with_transformer_engine_spec()) + + +def _reset_cudagraph_state(): + _CudagraphGlobalRecord.cudagraph_created = False + _CudagraphGlobalRecord.cudagraph_record = [] + CudaGraphManager.global_mempool = None + torch.cuda.synchronize() + + +def _all_layers_have_manager(block) -> bool: + return all(hasattr(layer, 'cudagraph_manager') for layer in block.layers) + + +def _no_layers_have_manager(block) -> bool: + return all(not 
hasattr(layer, 'cudagraph_manager') for layer in block.layers) + + +@pytest.mark.skipif( + not (HAVE_TE and is_te_min_version("1.5.0")), + reason="CUDA graph tests require TransformerEngine >= 1.5", +) +class TestTransformerLayerCudaGraphManagers: + def setup_method(self, method): + initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + _reset_cudagraph_state() + gc.collect() + + def test_empty_scope_transformer_layer_has_per_layer_manager(self): + block = _make_cuda_graph_gpt_block( + cuda_graph_impl='local', cuda_graph_modules=[], inference_cuda_graph_scope='layer' + ) + assert _all_layers_have_manager(block) + _reset_cudagraph_state() + + def test_empty_scope_transformer_block_no_per_layer_manager(self): + block = _make_cuda_graph_gpt_block( + cuda_graph_impl='local', cuda_graph_modules=[], inference_cuda_graph_scope='block' + ) + assert _no_layers_have_manager(block) + _reset_cudagraph_state() + + def test_deprecated_full_iteration_inference_scope_string_matches_new_granularity(self): + with pytest.warns( + DeprecationWarning, match="cuda_graph_modules 'full_iteration_inference' is deprecated" + ): + block = _make_cuda_graph_gpt_block( + cuda_graph_impl='local', cuda_graph_modules='full_iteration_inference' + ) + assert block.config.inference_cuda_graph_scope == InferenceCudaGraphScope.block + assert block.config.cuda_graph_modules == [] + assert _no_layers_have_manager(block) + _reset_cudagraph_state()
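+ + # A sketch beyond this change set, under the same conventions: restricting training + # capture via cuda_graph_modules should leave the inference default untouched, which + # test_local_impl_defaults_to_layer_scope (test_cuda_graphs.py) pins to layer + # granularity for impl='local'. The string module name 'attn' is assumed to coerce + # to CudaGraphModule.attn, as it does in that file. + def test_module_restricted_capture_keeps_layer_inference_default(self): + block = _make_cuda_graph_gpt_block( + cuda_graph_impl='local', cuda_graph_modules=['attn'] + ) + # Only the training-time capture set changed; the local implementation still + # defaults to per-layer inference graphs. + assert block.config.inference_cuda_graph_scope == InferenceCudaGraphScope.layer + _reset_cudagraph_state()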