diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 929b0f60765..e901e40597a 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -433,41 +433,227 @@ def __new__(cls, config: TransformerConfig): TEActivationOp = None +if HAVE_TE and is_te_min_version("1.13.0"): + + class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): + """ + RMSNorm with fused residual output for Megatron Core. + + Inherits from te.pytorch.RMSNorm to maintain all parameter management, + checkpoint compatibility, and Megatron-specific features. Creates a fused + implementation using TE's ops API that shares the base class parameters. + + The fused implementation uses: + - MakeExtraOutput: Forks the residual connection + - RMSNorm: Normalizes the main path + + Forward pass returns: (normalized_output, residual) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Fused implementation (stored in tuple to avoid submodule registration) + self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None + + def _make_fused_impl(self) -> te.pytorch.ops.Sequential: + """ + Construct fused ops pipeline that shares parameters with base RMSNorm. + + Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares + the weight parameter with self.weight from the base class. + """ + + fused_impl = te.pytorch.ops.Sequential() + + # Op 1: MakeExtraOutput - forks the residual + fused_impl.append(te.pytorch.ops.MakeExtraOutput()) + + # Op 2: RMSNorm - shares weight parameter with self + kwargs = { + "eps": self.eps, + "device": "meta", # Already initialized + "dtype": self.weight.dtype, + "zero_centered_gamma": self.zero_centered_gamma, + } + + # Add sm_margin if available (TE 2.5+) + if hasattr(self, '_sm_margins'): + kwargs["sm_margin"] = self._sm_margins + + rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) + + rmsnorm_op.weight = self.weight + + fused_impl.append(rmsnorm_op) + + self._register_hooks_on_fused_impl(fused_impl) + + return fused_impl + + def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None: + + forward_pre_hooks = [] + forward_post_hooks = [] + backward_pre_hooks = [] + backward_post_hooks = [] + + for submodule in self.modules(): + for hook in submodule._forward_pre_hooks.values(): + forward_pre_hooks.append((submodule, hook)) + for hook in submodule._forward_hooks.values(): + forward_post_hooks.append((submodule, hook)) + for hook in submodule._backward_pre_hooks.values(): + backward_pre_hooks.append((submodule, hook)) + for hook in submodule._backward_hooks.values(): + backward_post_hooks.append((submodule, hook)) + + # Pre-forward hooks + # Note: DDP pre-forward hooks are safe since they do not + # interact with input tensor. + if forward_pre_hooks: + from megatron.core.distributed import distributed_data_parallel + + if any( + inspect.getmodule(hook) != distributed_data_parallel + for _, hook in forward_pre_hooks + ): + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input tensor." + ) + + def forward_pre_hook(module, *_) -> None: + for submodule, hook in forward_pre_hooks: + # Assume that hook does not interact with input + ret = hook(submodule, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "pre-forward hook that modifies input tensor." + ) + + fused_impl.register_forward_pre_hook(forward_pre_hook) + + # Post-forward hooks + if forward_post_hooks: + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a post-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input or output tensors." + ) + + def forward_post_hook(module, *_) -> None: + for submodule, hook in forward_post_hooks: + # Assume that hook does not interact with input or output + ret = hook(submodule, None, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "post-forward hook that modifies output tensor." + ) + + fused_impl.register_forward_hook(forward_post_hook) + + # Backward hooks + if backward_pre_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support " + "submodules with pre-backward hooks" + ) + if backward_post_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support " + "submodules with post-backward hooks" + ) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass with fused residual output. + + Args: + hidden_states: Input tensor [s, b, h] + + Returns: + Tuple of (normalized_output, residual), both [s, b, h] + + Note: + Sequential.forward() automatically returns (output, extra_outputs...) + when MakeExtraOutput is present, so we don't need manual unpacking. + """ + + # Construct fused impl lazily on first forward + # (in case parameters are modified after __init__) + if self._fused_impl is None: + self._fused_impl = (self._make_fused_impl(),) + + # Apply fused implementation + # Sequential returns (normalized_output, residual) automatically + return self._fused_impl[0](hidden_states) + +else: + TEFusedResidualRMSNorm = None # type: ignore[assignment, misc] + + class TENorm: """A conditional wrapper to initialize an instance of - Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.""" + Transformer-Engine's `LayerNorm` or `RMSNorm` based on input. + + Residual fusion is a two-level opt-in mechanism: + + 1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature) + 2. Local intent: has_residual=True must be passed at build site (declares this specific + norm is followed by a residual connection) + + Fusion only happens when BOTH conditions are met. + + """ # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? def __new__( - cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5 - ) -> LayerNormInterface: + cls, + config: TransformerConfig, + hidden_size: int, + eps: float = 1e-5, + has_residual: bool = False, + ): if not HAVE_TE: raise ImportError( "Transformer Engine is not installed. " "Please install it with `pip install transformer-engine`." ) + use_fused_residual = config.fused_residual_rmsnorm and has_residual + if use_fused_residual and config.normalization != "RMSNorm": + raise ValueError("Fused residual is only supported " "for RMSNorm normalization") + if config.normalization == "LayerNorm": - instance = te.pytorch.LayerNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) + norm_module = te.pytorch.LayerNorm elif config.normalization == "RMSNorm": assert hasattr( te.pytorch, "RMSNorm" ), "Transformer-Engine >= v0.11 required to use this feature" - instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) + if use_fused_residual: + assert ( + TEFusedResidualRMSNorm is not None + ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" + norm_module = TEFusedResidualRMSNorm + else: + norm_module = te.pytorch.RMSNorm else: - raise Exception("Only LayerNorm and RMSNorm are curently supported") + raise Exception("Only LayerNorm and RMSNorm are currently supported") + + instance = norm_module( + normalized_shape=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) return cast(LayerNormInterface, instance) diff --git a/megatron/core/extensions/transformer_engine_spec_provider.py b/megatron/core/extensions/transformer_engine_spec_provider.py index 4ac6a061552..d6e6b1f8f33 100644 --- a/megatron/core/extensions/transformer_engine_spec_provider.py +++ b/megatron/core/extensions/transformer_engine_spec_provider.py @@ -28,6 +28,13 @@ from megatron.core.utils import get_te_version, is_te_min_version +class _TENormWithResidual: + """Class adapter for TENorm with residual fusion enabled.""" + + def __new__(cls, *args, **kwargs): + return TENorm(*args, has_residual=True, **kwargs) + + class TESpecProvider(BackendSpecProvider): """A protocol for providing the submodules used in Spec building.""" @@ -51,14 +58,17 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: """Which module for sequential layernorm and linear""" return TELayerNormColumnParallelLinear - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module to use for layer norm""" if for_qk and not is_te_min_version("1.9.0"): # TENorm significantly harms convergence when used # for QKLayerNorm if TE Version < 1.9; # we instead use the Apex implementation. return FusedLayerNorm - return TENorm + # Keep returning a class so this path stays aligned with build_module's class handling. + return _TENormWithResidual if has_residual else TENorm def core_attention(self) -> type: """Which module to use for attention""" diff --git a/megatron/core/models/backends.py b/megatron/core/models/backends.py index f3995519595..e867e91c003 100644 --- a/megatron/core/models/backends.py +++ b/megatron/core/models/backends.py @@ -72,7 +72,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: ... @abstractmethod - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module for layernorm""" ... @@ -113,7 +115,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: """Which module for sequential layernorm and linear""" return None - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module to use for layer norm""" if rms_norm: # Matching get_gpt_layer_local_spec. @@ -162,7 +166,9 @@ def column_parallel_layer_norm_linear(self) -> type[InferenceLayerNormColumnPara """Which module for sequential layernorm and linear""" return InferenceLayerNormColumnParallelLinear - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module to use for layer norm""" if for_qk and not is_te_min_version("1.9.0"): # TENorm significantly harms convergence when used diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 103601a3be0..328d1c2f07f 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -110,7 +110,7 @@ def get_gpt_layer_with_inference_submodules( else backend.column_parallel_linear() ) return TransformerLayerSubmodules( - input_layernorm=backend.layer_norm(), + input_layernorm=backend.layer_norm(has_residual=True), self_attention=ModuleSpec( module=MLASelfAttention, params={"attn_mask_type": AttnMaskType.causal}, @@ -244,7 +244,7 @@ def get_gpt_layer_with_transformer_engine_submodules( else backend.column_parallel_linear() ) return TransformerLayerSubmodules( - input_layernorm=backend.layer_norm(), + input_layernorm=backend.layer_norm(has_residual=True), self_attention=ModuleSpec( module=MLASelfAttention, params={"attn_mask_type": AttnMaskType.causal}, @@ -261,7 +261,7 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, ) @@ -284,7 +284,7 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, sharded_state_dict_keys_map={ @@ -345,10 +345,10 @@ def get_gpt_layer_local_submodules( backend = LocalSpecProvider() # Adjust for RMS norm. if normalization == "RMSNorm": - layer_norm = backend.layer_norm(rms_norm=True, for_qk=False) + layer_norm = backend.layer_norm(rms_norm=True, for_qk=False, has_residual=True) qk_norm = backend.layer_norm(rms_norm=True, for_qk=True) else: - layer_norm = backend.layer_norm(rms_norm=False, for_qk=False) + layer_norm = backend.layer_norm(rms_norm=False, for_qk=False, has_residual=True) qk_norm = backend.layer_norm(rms_norm=False, for_qk=True) if fp8 is not None: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 642af8415d3..fd4025de9f7 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -446,6 +446,9 @@ class TransformerConfig(ModelParallelConfig): fused_single_qkv_rope: bool = False """If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads.""" + fused_residual_rmsnorm: bool = False + """If True, fuses residual connection and RMSNorm backward pass when TE is used.""" + #################### # activation recomputation #################### @@ -1635,6 +1638,12 @@ def __post_init__(self): "to True and use_te_activation_func to False." ) + if self.fused_residual_rmsnorm: + if self.normalization != "RMSNorm": + raise ValueError( + "fused_residual_rmsnorm is only supported when normalization is RMSNorm." + ) + if self.use_te_activation_func: if self.activation_func not in (F.gelu, F.silu, F.relu): raise ValueError( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 90c5f6e5084..7cf565e1eec 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -576,11 +576,6 @@ def _forward_attention( inference_context = deprecate_inference_params(inference_context, inference_params) - # Residual connection. - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() - # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() @@ -592,6 +587,20 @@ def _forward_attention( with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: input_layernorm_output = apply_module(self.input_layernorm)(hidden_states) + if isinstance(input_layernorm_output, tuple): + if len(input_layernorm_output) != 2: + raise ValueError( + f"When the output of input_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(input_layernorm_output)}" + ) + input_layernorm_output, residual = input_layernorm_output + else: + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() + using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication ) @@ -646,14 +655,23 @@ def _forward_attention( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) - # Residual connection. - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() - # Optional Layer norm after self-attention pre_cross_attn_layernorm_output = apply_module(self.pre_cross_attn_layernorm)(hidden_states) + if isinstance(pre_cross_attn_layernorm_output, tuple): + if len(pre_cross_attn_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_cross_attn_layernorm_output " + f"is a tuple, it is expected to have 2 elements " + f"(output, residual), but " + f"got {len(pre_cross_attn_layernorm_output)}" + ) + pre_cross_attn_layernorm_output, residual = pre_cross_attn_layernorm_output + else: + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() # Cross attention. attention_output_with_bias = self.cross_attention( pre_cross_attn_layernorm_output, @@ -728,14 +746,24 @@ def _forward_mlp( output (Tensor): Transformed hidden states of shape [s, b, h]. """ - # Residual connection. - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() - # Optional Layer norm post the cross-attention. pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + # Residual connection. + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() + nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size should_chunk_mlp_for_prefill = ( @@ -1147,6 +1175,15 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if not self.is_moe_layer: return residual, None, None, None hidden_states = apply_module(self.pre_mlp_layernorm)(residual) + if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) + hidden_states, residual = hidden_states + shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) @@ -1350,11 +1387,22 @@ def _forward_mlp_router(self, hidden_states, padding_mask=None): This method is isolated so it can be captured by `cudagraph_manager_router`. """ - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() self.mlp.fwd_execution_map = "route" pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() + router_outputs = self.mlp( pre_mlp_layernorm_output, intermediate_tensors=(), padding_mask=padding_mask ) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py new file mode 100644 index 00000000000..6c03e0fa801 --- /dev/null +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import pytest +import torch +from transformer_engine.pytorch import RMSNorm + +from megatron.core.extensions.transformer_engine import TEFusedResidualRMSNorm + + +def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): + return rmsnorm(x), x + + +@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("normalized_shape", [256, 256 * 2, 256 * 4]) +def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): + x_baseline = torch.randn(16, 32, normalized_shape, dtype=input_dtype, device="cuda") + x_baseline.requires_grad = True + x_fused = x_baseline.detach() + x_fused.requires_grad = True + baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() + fused_rmsnorm = TEFusedResidualRMSNorm( + normalized_shape=normalized_shape, dtype=input_dtype + ).cuda() + + # baseline + baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) + baseline_loss = baseline_y.sum() + baseline_residual.sum() + baseline_loss.backward() + + # fused + fused_y, fused_residual = fused_rmsnorm(x_fused) + fused_loss = fused_y.sum() + fused_residual.sum() + fused_loss.backward() + + # Use tolerances appropriate for dtype (pattern from other tests) + tols = ( + dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + ) + + assert fused_y.dtype == baseline_y.dtype + assert torch.allclose(fused_y, baseline_y, **tols) + assert fused_residual.dtype == baseline_residual.dtype + assert torch.allclose(fused_residual, baseline_residual, **tols) + assert x_fused.grad.dtype == x_baseline.grad.dtype + assert torch.allclose(x_baseline.grad, x_fused.grad, **tols) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index f8d1cde7028..c3894a8cb67 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -113,6 +113,7 @@ "fp8_quantizer_factory": None, "fp8_recipe": "delayed", "fp8_wgrad": True, + "fused_residual_rmsnorm": False, "fused_single_qkv_rope": False, "gated_linear_unit": False, "glu_linear_offset": 0.0,