From 2593e53cb99c3e92dec226448c10a9f9a2b7d570 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 26 Jan 2026 16:29:16 +0100 Subject: [PATCH 01/30] add fused te fused layernorm --- .../core/extensions/transformer_engine.py | 11 +++++-- .../core/transformer/transformer_layer.py | 29 ++++++++++++------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bb913d97446..ae5f860bbab 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -466,8 +466,16 @@ def __new__( zero_centered_gamma=config.layernorm_zero_centered_gamma, **_get_extra_te_kwargs(config), ) + elif config.normalization == "ResidualRMSNorm": + extra_te_kwargs = _get_extra_te_kwargs(config) + extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] + del extra_te_kwargs["params_dtype"] + instance = te.pytorch.ops.Sequential( + te.pytorch.ops.MakeExtraOutput(), + te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), + ) else: - raise Exception("Only LayerNorm and RMSNorm are curently supported") + raise Exception("Only LayerNorm, RMSNorm and ResidualRMSNorm are curently supported") return cast(LayerNormInterface, instance) @@ -2207,7 +2215,6 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Option else: TEFusedMLP = None # type: ignore[assignment, misc] - class TEDelayedScaling(te.common.recipe.DelayedScaling): """ Wrapper for the Transformer-Engine's `DelayedScaling` layer. diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 73be3496876..586614ffebf 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -573,11 +573,6 @@ def _forward_attention( inference_context = deprecate_inference_params(inference_context, inference_params) - # Residual connection. - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() - # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() @@ -589,6 +584,16 @@ def _forward_attention( with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states: input_layernorm_output = apply_module(self.input_layernorm)(hidden_states) + if isinstance(input_layernorm_output, tuple): + if len(input_layernorm_output) != 2: + raise ValueError(f"When the output of input_layernorm is a tuple, it is expected to have 2 elements (output, residual), but got {len(input_layernorm_output)}") + input_layernorm_output, residual = input_layernorm_output + else: + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() + using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication ) @@ -643,14 +648,18 @@ def _forward_attention( hidden_states, name="attn_norm", forced_released_tensors=[residual] ) - # Residual connection. - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() - # Optional Layer norm after self-attention pre_cross_attn_layernorm_output = apply_module(self.pre_cross_attn_layernorm)(hidden_states) + if isinstance(pre_cross_attn_layernorm_output, tuple): + if len(pre_cross_attn_layernorm_output) != 2: + raise ValueError(f"When the output of pre_cross_attn_layernorm_output is a tuple, it is expected to have 2 elements (output, residual), but got {len(pre_cross_attn_layernorm_output)}") + pre_cross_attn_layernorm_output, residual = pre_cross_attn_layernorm_output + else: + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() # Cross attention. attention_output_with_bias = self.cross_attention( pre_cross_attn_layernorm_output, From d74e0062eda4182db282549f9983199ba0688279 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 02/30] revert changes to normalization parameter, add fusion flag instead --- .../core/extensions/transformer_engine.py | 30 ++++++++++--------- .../core/transformer/transformer_config.py | 9 +++++- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ae5f860bbab..053dbbfbacb 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -459,21 +459,23 @@ def __new__( assert hasattr( te.pytorch, "RMSNorm" ), "Transformer-Engine >= v0.11 required to use this feature" - instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), - ) - elif config.normalization == "ResidualRMSNorm": + extra_te_kwargs = _get_extra_te_kwargs(config) - extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] - del extra_te_kwargs["params_dtype"] - instance = te.pytorch.ops.Sequential( - te.pytorch.ops.MakeExtraOutput(), - te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), - ) + if config.fused_residual_rmsnorm: + extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] + del extra_te_kwargs["params_dtype"] + instance = te.pytorch.ops.Sequential( + te.pytorch.ops.MakeExtraOutput(), + te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), + ) + else: + instance = te.pytorch.RMSNorm( + hidden_size=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **extra_te_kwargs, + ) else: raise Exception("Only LayerNorm, RMSNorm and ResidualRMSNorm are curently supported") diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 559f4226af2..7fe340c9e25 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -446,6 +446,9 @@ class TransformerConfig(ModelParallelConfig): fused_single_qkv_rope: bool = False """If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads.""" + fused_residual_rmsnorm: bool = False + """If True, uses fuses residual connection and RMSNorm when TE is used.""" + #################### # activation recomputation #################### @@ -1638,7 +1641,11 @@ def __post_init__(self): "If you use bias in MLP FC1, we recommend setting bias_activation_fusion " "to True and use_te_activation_func to False." ) - + + if self.fused_residual_rmsnorm: + if self.normalization != "RMSNorm": + raise ValueError("fused_residual_rmsnorm is only supported when normalization is RMSNorm.") + if self.use_te_activation_func: if self.activation_func not in (F.gelu, F.silu, F.relu): raise ValueError( From 6c1364f498a3dcd30062780fd7cd99f2c38c63f5 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 03/30] Refactor TEFusedResidualRMSNorm properly wrapping it for compatibility with mcore --- .../core/extensions/transformer_engine.py | 198 +++++++++++++++++- 1 file changed, 191 insertions(+), 7 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 053dbbfbacb..19d281f4f93 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -433,6 +433,183 @@ def __new__(cls, config: TransformerConfig): TEActivationOp = None +if HAVE_TE and is_te_min_version("1.13.0"): + + class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): + """ + RMSNorm with fused residual output for Megatron Core. + + Inherits from te.pytorch.RMSNorm to maintain all parameter management, + checkpoint compatibility, and Megatron-specific features. Creates a fused + implementation using TE's ops API that shares the base class parameters. + + The fused implementation uses: + - MakeExtraOutput: Forks the residual connection + - RMSNorm: Normalizes the main path + + Forward pass returns: (normalized_output, residual) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Fused implementation (stored in tuple to avoid submodule registration) + self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None + + def _make_fused_impl(self) -> te.pytorch.ops.Sequential: + """ + Construct fused ops pipeline that shares parameters with base RMSNorm. + + Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares + the weight parameter with self.weight from the base class. + """ + + fused_impl = te.pytorch.ops.Sequential() + + # Op 1: MakeExtraOutput - forks the residual + fused_impl.append(te.pytorch.ops.MakeExtraOutput()) + + # Op 2: RMSNorm - shares weight parameter with self + kwargs = { + "eps": self.eps, + "device": "meta", # Already initialized + "dtype": self.weight.dtype, + "zero_centered_gamma": self.zero_centered_gamma, + } + + # Add sm_margin if available (TE 2.5+) + if hasattr(self, '_sm_margins'): + kwargs["sm_margin"] = self._sm_margins + + rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) + + # CRITICAL: Share the weight parameter with base class + # This ensures checkpointing works through the base class + rmsnorm_op.weight = self.weight + + fused_impl.append(rmsnorm_op) + + # Transfer hooks from base module to fused implementation + # This is CRITICAL for DDP to work correctly + self._register_hooks_on_fused_impl(fused_impl) + + return fused_impl + + def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None: + """ + Transfer hooks from base RMSNorm to fused implementation. + + This is critical for distributed training - DDP registers hooks on the + base module that must be executed. Follows TEFusedMLP pattern. + + Note: Transformer Engine's op fuser does not expose intermediate tensors, + so hooks that modify tensors will not work correctly. + """ + + # Collect hooks from all submodules (including self) + forward_pre_hooks = [] + forward_post_hooks = [] + backward_pre_hooks = [] + backward_post_hooks = [] + + for submodule in self.modules(): + for hook in submodule._forward_pre_hooks.values(): + forward_pre_hooks.append((submodule, hook)) + for hook in submodule._forward_hooks.values(): + forward_post_hooks.append((submodule, hook)) + for hook in submodule._backward_pre_hooks.values(): + backward_pre_hooks.append((submodule, hook)) + for hook in submodule._backward_hooks.values(): + backward_post_hooks.append((submodule, hook)) + + # Pre-forward hooks + # Note: DDP pre-forward hooks are safe since they do not + # interact with input tensor. + if forward_pre_hooks: + from megatron.core.distributed import distributed_data_parallel + + if any( + inspect.getmodule(hook) != distributed_data_parallel + for _, hook in forward_pre_hooks + ): + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input tensor." + ) + + def forward_pre_hook(module, *_) -> None: + for submodule, hook in forward_pre_hooks: + # Assume that hook does not interact with input + ret = hook(submodule, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " + "submodule has pre-forward hook that modifies input tensor." + ) + + fused_impl.register_forward_pre_hook(forward_pre_hook) + + # Post-forward hooks + if forward_post_hooks: + warnings.warn( + "TEFusedResidualRMSNorm module has a submodule with a post-forward hook. " + "TEFusedResidualRMSNorm module does not expose intermediate tensors, " + "so the hook may have incorrect behavior if it attempts to " + "access the input or output tensors." + ) + + def forward_post_hook(module, *_) -> None: + for submodule, hook in forward_post_hooks: + # Assume that hook does not interact with input or output + ret = hook(submodule, None, None) + if ret is not None: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " + "submodule has post-forward hook that modifies output tensor." + ) + + fused_impl.register_forward_hook(forward_post_hook) + + # Backward hooks + if backward_pre_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support submodules with pre-backward hooks" + ) + if backward_post_hooks: + raise RuntimeError( + "TEFusedResidualRMSNorm module does not support submodules with post-backward hooks" + ) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass with fused residual output. + + Args: + hidden_states: Input tensor [s, b, h] + + Returns: + Tuple of (normalized_output, residual), both [s, b, h] + + Note: + Sequential.forward() automatically returns (output, extra_outputs...) + when MakeExtraOutput is present, so we don't need manual unpacking. + """ + + # Construct fused impl lazily on first forward + # (in case parameters are modified after __init__) + if self._fused_impl is None: + self._fused_impl = (self._make_fused_impl(),) + + # Apply fused implementation + # Sequential returns (normalized_output, residual) automatically + return self._fused_impl[0](hidden_states) + +else: + TEFusedResidualRMSNorm = None # type: ignore[assignment, misc] + + class TENorm: """A conditional wrapper to initialize an instance of Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.""" @@ -459,18 +636,25 @@ def __new__( assert hasattr( te.pytorch, "RMSNorm" ), "Transformer-Engine >= v0.11 required to use this feature" - + extra_te_kwargs = _get_extra_te_kwargs(config) + if config.fused_residual_rmsnorm: - extra_te_kwargs["dtype"] = extra_te_kwargs["params_dtype"] - del extra_te_kwargs["params_dtype"] - instance = te.pytorch.ops.Sequential( - te.pytorch.ops.MakeExtraOutput(), - te.pytorch.ops.RMSNorm(normalized_shape=hidden_size, eps=eps, zero_centered_gamma=config.layernorm_zero_centered_gamma, **extra_te_kwargs), + # Use fused residual variant + assert TEFusedResidualRMSNorm is not None, ( + "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" + ) + instance = TEFusedResidualRMSNorm( + normalized_shape=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **extra_te_kwargs, ) else: + # Standard RMSNorm without fusion instance = te.pytorch.RMSNorm( - hidden_size=hidden_size, + normalized_shape=hidden_size, eps=eps, sequence_parallel=config.sequence_parallel, zero_centered_gamma=config.layernorm_zero_centered_gamma, From 57596cf79803bf5d887c7159275616aa31771047 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 04/30] add more spots where tuple outputs break mcore --- megatron/core/extensions/transformer_engine.py | 2 +- megatron/core/transformer/multi_token_prediction.py | 6 ++++++ megatron/core/transformer/transformer_block.py | 6 +++++- megatron/core/transformer/transformer_layer.py | 11 +++++++++++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 19d281f4f93..bef218a0c52 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -661,7 +661,7 @@ def __new__( **extra_te_kwargs, ) else: - raise Exception("Only LayerNorm, RMSNorm and ResidualRMSNorm are curently supported") + raise Exception("Only LayerNorm and RMSNorm are curently supported") return cast(LayerNormInterface, instance) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 4ad2e517cfc..f7928f43140 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -897,8 +897,12 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T Concatenate the tokens before sending to transformer layer. """ decoder_input = apply_module(self.enorm)(decoder_input) + if isinstance(decoder_input, tuple): + decoder_input = decoder_input[0] decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True) hidden_states = apply_module(self.hnorm)(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states # and the (i + K)-th token's embedding, and combine them with linear projection. @@ -991,6 +995,8 @@ def _postprocess(self, hidden_states: torch.Tensor): # Layer norm before shared head layer. hidden_states = apply_module(self.final_layernorm)(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 703f6c266f1..6c3a97a19ce 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -855,7 +855,11 @@ def forward( # Final layer norm. if self.final_layernorm is not None: - hidden_states = apply_module(self.final_layernorm)(cast(Tensor, hidden_states)) + hidden_states = self.final_layernorm(hidden_states) + # Handle fused residual normalization (returns tuple of (output, residual)) + # For final layernorm, we only need the normalized output, not the residual + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 586614ffebf..a4a72f969d9 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -742,6 +742,15 @@ def _forward_mlp( # Optional Layer norm post the cross-attention. pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + # Handle fused residual normalization (returns tuple of (output, residual)) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " + f"2 elements (output, residual), but got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size should_chunk_mlp_for_prefill = ( @@ -1153,6 +1162,8 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if not self.is_moe_layer: return residual, None, None, None hidden_states = apply_module(self.pre_mlp_layernorm)(residual) + if isinstance(hidden_states, tuple): + hidden_states, residual = hidden_states shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) From ba5c4a5a8bd314ff1c1804ca6f09bab273e5ecf1 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 05/30] remove excessive comments --- megatron/core/extensions/transformer_engine.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bef218a0c52..a7101a04fda 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -483,30 +483,16 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) - # CRITICAL: Share the weight parameter with base class - # This ensures checkpointing works through the base class rmsnorm_op.weight = self.weight fused_impl.append(rmsnorm_op) - # Transfer hooks from base module to fused implementation - # This is CRITICAL for DDP to work correctly self._register_hooks_on_fused_impl(fused_impl) return fused_impl def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None: - """ - Transfer hooks from base RMSNorm to fused implementation. - - This is critical for distributed training - DDP registers hooks on the - base module that must be executed. Follows TEFusedMLP pattern. - - Note: Transformer Engine's op fuser does not expose intermediate tensors, - so hooks that modify tensors will not work correctly. - """ - # Collect hooks from all submodules (including self) forward_pre_hooks = [] forward_post_hooks = [] backward_pre_hooks = [] From 20393df71a5ed70928b617211bfef67afa124942 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 06/30] add quantization --- megatron/core/extensions/transformer_engine.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index a7101a04fda..ca286a1f44e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -450,8 +450,9 @@ class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): Forward pass returns: (normalized_output, residual) """ - def __init__(self, *args, **kwargs): + def __init__(self, quantize: bool, *args, **kwargs): super().__init__(*args, **kwargs) + self.quantize = quantize # Fused implementation (stored in tuple to avoid submodule registration) self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None @@ -480,6 +481,9 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: # Add sm_margin if available (TE 2.5+) if hasattr(self, '_sm_margins'): kwargs["sm_margin"] = self._sm_margins + + if self.quantize: + fused_impl.append(te.ops.Quantize(forward=False, backward=True)) rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) @@ -487,6 +491,9 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: fused_impl.append(rmsnorm_op) + if self.quantize: + fused_impl.append(te.ops.Quantize(forward=True, backward=False)) + self._register_hooks_on_fused_impl(fused_impl) return fused_impl @@ -635,6 +642,7 @@ def __new__( eps=eps, sequence_parallel=config.sequence_parallel, zero_centered_gamma=config.layernorm_zero_centered_gamma, + quantize=config.fp8 or config.fp4, **extra_te_kwargs, ) else: From 5d1c460be59349ce8fd505f2cbd7e43376b2d170 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 07/30] add rmsnorm residual fusion test --- .../fusions/test_rmsnorm_residual_fusion.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py new file mode 100644 index 00000000000..cd6a2dba4f4 --- /dev/null +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from megatron.core.extensions.transformer_engine import TEFusedResidualRMSNorm +from transformer_engine.pytorch import RMSNorm + +def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): + return x, rmsnorm(x) + +@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("normalized_shape", [256, 256*2, 256*4]) +def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): + x_baseline = torch.randn(16, 32, 1024, dtype=input_dtype, device="cuda") + x_baseline.requires_grad = True + x_fused = x_baseline.detach() + x_fused.requires_grad = True + baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape).cuda() + fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, quantize=False).cuda() + + # baseline + baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) + baseline_loss = baseline_y.sum() + baseline_residual.sum() + baseline_loss.backward() + + # fused + fused_y, fused_residual = fused_rmsnorm(x_fused) + fused_loss = fused_y.sum() + fused_residual.sum() + fused_loss.backward() + + # Use tolerances appropriate for dtype (pattern from other tests) + tols = dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + + assert fused_y.dtype == baseline_y.dtype + assert torch.allclose(fused_y, baseline_y, **tols) + assert fused_residual.dtype == baseline_residual.dtype + assert torch.allclose(fused_residual, baseline_residual, **tols) + assert x_fused.grad.dtype == x_baseline.grad.dtype + assert torch.allclose(x_baseline.grad, x_fused.grad, **tols) From c58a797e843261cde447f1847331e6aeb977d3ea Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 08/30] fix tests --- tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index cd6a2dba4f4..59ad8212eb8 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -5,17 +5,17 @@ from transformer_engine.pytorch import RMSNorm def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): - return x, rmsnorm(x) + return rmsnorm(x), x @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("normalized_shape", [256, 256*2, 256*4]) def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): - x_baseline = torch.randn(16, 32, 1024, dtype=input_dtype, device="cuda") + x_baseline = torch.randn(16, 32, normalized_shape, dtype=input_dtype, device="cuda") x_baseline.requires_grad = True x_fused = x_baseline.detach() x_fused.requires_grad = True - baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape).cuda() - fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, quantize=False).cuda() + baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() + fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, dtype=input_dtype, quantize=False).cuda() # baseline baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) From 5d2bddc7995a539fed1f475df8fbe60ea3508ec1 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:10:09 +0100 Subject: [PATCH 09/30] dont use residual_add when not necessary --- .../core/extensions/transformer_engine.py | 23 +++++++++++++++---- .../core/transformer/transformer_layer.py | 3 +++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ca286a1f44e..2520ea9e85f 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -605,19 +605,32 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens class TENorm: """A conditional wrapper to initialize an instance of - Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.""" + Transformer-Engine's `LayerNorm` or `RMSNorm` based on input. + + Residual Fusion Design: + ---------------------- + Residual fusion is a two-level opt-in mechanism: + + 1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature) + 2. Local intent: has_residual=True must be passed at build site (declares this specific + norm is followed by a residual connection) + + Fusion only happens when BOTH conditions are met. + + """ # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? def __new__( - cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5 - ) -> LayerNormInterface: + cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5, has_residual: bool = False + ): if not HAVE_TE: raise ImportError( "Transformer Engine is not installed. " "Please install it with `pip install transformer-engine`." ) - if config.normalization == "LayerNorm": + if config.fused_residual_rmsnorm and has_residual: + raise ValueError("Fused residual RMSNorm is not supported for LayerNorm") instance = te.pytorch.LayerNorm( hidden_size=hidden_size, eps=eps, @@ -632,7 +645,7 @@ def __new__( extra_te_kwargs = _get_extra_te_kwargs(config) - if config.fused_residual_rmsnorm: + if config.fused_residual_rmsnorm and has_residual: # Use fused residual variant assert TEFusedResidualRMSNorm is not None, ( "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index a4a72f969d9..b92218d1883 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -308,6 +308,7 @@ def __init__( config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + has_residual=True, # Followed by self-attention + residual add ) attention_optional_kwargs = {} @@ -337,6 +338,7 @@ def __init__( config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + has_residual=True, # Followed by cross-attention + residual add ) # [Module 5: CrossAttention] @@ -355,6 +357,7 @@ def __init__( config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, + has_residual=True, # Followed by MLP + residual add ) # [Module 8: MLP block] additional_mlp_kwargs = {} From 6779413bc696a371d3651a01d2767a52eff94c5a Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:40:53 +0100 Subject: [PATCH 10/30] Remove quantization for now --- megatron/core/extensions/transformer_engine.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 2520ea9e85f..fcaed3d33b4 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -450,10 +450,8 @@ class TEFusedResidualRMSNorm(te.pytorch.RMSNorm): Forward pass returns: (normalized_output, residual) """ - def __init__(self, quantize: bool, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.quantize = quantize - # Fused implementation (stored in tuple to avoid submodule registration) self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None @@ -482,18 +480,12 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: if hasattr(self, '_sm_margins'): kwargs["sm_margin"] = self._sm_margins - if self.quantize: - fused_impl.append(te.ops.Quantize(forward=False, backward=True)) - rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) rmsnorm_op.weight = self.weight fused_impl.append(rmsnorm_op) - if self.quantize: - fused_impl.append(te.ops.Quantize(forward=True, backward=False)) - self._register_hooks_on_fused_impl(fused_impl) return fused_impl @@ -655,7 +647,6 @@ def __new__( eps=eps, sequence_parallel=config.sequence_parallel, zero_centered_gamma=config.layernorm_zero_centered_gamma, - quantize=config.fp8 or config.fp4, **extra_te_kwargs, ) else: From d490afe7de2c481ee10ed68b604f285406d3a602 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 11:56:24 +0100 Subject: [PATCH 11/30] formatting changes --- .../core/extensions/transformer_engine.py | 31 ++++++++++++------- .../core/transformer/transformer_block.py | 4 +-- .../core/transformer/transformer_config.py | 8 +++-- .../core/transformer/transformer_layer.py | 13 ++++++-- .../fusions/test_rmsnorm_residual_fusion.py | 16 +++++++--- 5 files changed, 49 insertions(+), 23 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index fcaed3d33b4..bfea8f88d5e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -479,7 +479,7 @@ def _make_fused_impl(self) -> te.pytorch.ops.Sequential: # Add sm_margin if available (TE 2.5+) if hasattr(self, '_sm_margins'): kwargs["sm_margin"] = self._sm_margins - + rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs) rmsnorm_op.weight = self.weight @@ -530,8 +530,9 @@ def forward_pre_hook(module, *_) -> None: ret = hook(submodule, None) if ret is not None: raise RuntimeError( - "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " - "submodule has pre-forward hook that modifies input tensor." + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "pre-forward hook that modifies input tensor." ) fused_impl.register_forward_pre_hook(forward_pre_hook) @@ -551,8 +552,9 @@ def forward_post_hook(module, *_) -> None: ret = hook(submodule, None, None) if ret is not None: raise RuntimeError( - "TEFusedResidualRMSNorm module does not expose intermediate tensors, but " - "submodule has post-forward hook that modifies output tensor." + "TEFusedResidualRMSNorm module does not expose " + "intermediate tensors, but submodule has " + "post-forward hook that modifies output tensor." ) fused_impl.register_forward_hook(forward_post_hook) @@ -560,11 +562,13 @@ def forward_post_hook(module, *_) -> None: # Backward hooks if backward_pre_hooks: raise RuntimeError( - "TEFusedResidualRMSNorm module does not support submodules with pre-backward hooks" + "TEFusedResidualRMSNorm module does not support " + "submodules with pre-backward hooks" ) if backward_post_hooks: raise RuntimeError( - "TEFusedResidualRMSNorm module does not support submodules with post-backward hooks" + "TEFusedResidualRMSNorm module does not support " + "submodules with post-backward hooks" ) def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -613,7 +617,11 @@ class TENorm: # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm? def __new__( - cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5, has_residual: bool = False + cls, + config: TransformerConfig, + hidden_size: int, + eps: float = 1e-5, + has_residual: bool = False, ): if not HAVE_TE: raise ImportError( @@ -639,9 +647,9 @@ def __new__( if config.fused_residual_rmsnorm and has_residual: # Use fused residual variant - assert TEFusedResidualRMSNorm is not None, ( - "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" - ) + assert ( + TEFusedResidualRMSNorm is not None + ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" instance = TEFusedResidualRMSNorm( normalized_shape=hidden_size, eps=eps, @@ -2399,6 +2407,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Option else: TEFusedMLP = None # type: ignore[assignment, misc] + class TEDelayedScaling(te.common.recipe.DelayedScaling): """ Wrapper for the Transformer-Engine's `DelayedScaling` layer. diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 6c3a97a19ce..a5a80cf8943 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -2,7 +2,7 @@ import logging from contextlib import nullcontext from dataclasses import dataclass -from typing import List, Optional, Set, Union, cast +from typing import List, Optional, Set, Union import torch from torch import Tensor @@ -28,7 +28,7 @@ get_transformer_layer_offset, ) from megatron.core.transformer.utils import sharded_state_dict_default -from megatron.core.typed_torch import apply_module, not_none +from megatron.core.typed_torch import not_none from megatron.core.utils import ( WrappedTensor, deprecate_inference_params, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 7fe340c9e25..8ef3cff98c7 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1641,11 +1641,13 @@ def __post_init__(self): "If you use bias in MLP FC1, we recommend setting bias_activation_fusion " "to True and use_te_activation_func to False." ) - + if self.fused_residual_rmsnorm: if self.normalization != "RMSNorm": - raise ValueError("fused_residual_rmsnorm is only supported when normalization is RMSNorm.") - + raise ValueError( + "fused_residual_rmsnorm is only supported when normalization is RMSNorm." + ) + if self.use_te_activation_func: if self.activation_func not in (F.gelu, F.silu, F.relu): raise ValueError( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index b92218d1883..381db1bbe0b 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -589,7 +589,11 @@ def _forward_attention( if isinstance(input_layernorm_output, tuple): if len(input_layernorm_output) != 2: - raise ValueError(f"When the output of input_layernorm is a tuple, it is expected to have 2 elements (output, residual), but got {len(input_layernorm_output)}") + raise ValueError( + f"When the output of input_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(input_layernorm_output)}" + ) input_layernorm_output, residual = input_layernorm_output else: residual = hidden_states @@ -656,7 +660,12 @@ def _forward_attention( if isinstance(pre_cross_attn_layernorm_output, tuple): if len(pre_cross_attn_layernorm_output) != 2: - raise ValueError(f"When the output of pre_cross_attn_layernorm_output is a tuple, it is expected to have 2 elements (output, residual), but got {len(pre_cross_attn_layernorm_output)}") + raise ValueError( + f"When the output of pre_cross_attn_layernorm_output " + f"is a tuple, it is expected to have 2 elements " + f"(output, residual), but " + f"got {len(pre_cross_attn_layernorm_output)}" + ) pre_cross_attn_layernorm_output, residual = pre_cross_attn_layernorm_output else: residual = hidden_states diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index 59ad8212eb8..a7504a6e8d0 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -1,22 +1,26 @@ import pytest import torch +from transformer_engine.pytorch import RMSNorm from megatron.core.extensions.transformer_engine import TEFusedResidualRMSNorm -from transformer_engine.pytorch import RMSNorm + def baseline_rmsnorm_residual(x, rmsnorm: RMSNorm): return rmsnorm(x), x + @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) -@pytest.mark.parametrize("normalized_shape", [256, 256*2, 256*4]) +@pytest.mark.parametrize("normalized_shape", [256, 256 * 2, 256 * 4]) def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): x_baseline = torch.randn(16, 32, normalized_shape, dtype=input_dtype, device="cuda") x_baseline.requires_grad = True x_fused = x_baseline.detach() x_fused.requires_grad = True baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() - fused_rmsnorm = TEFusedResidualRMSNorm(normalized_shape=normalized_shape, dtype=input_dtype, quantize=False).cuda() - + fused_rmsnorm = TEFusedResidualRMSNorm( + normalized_shape=normalized_shape, dtype=input_dtype, quantize=False + ).cuda() + # baseline baseline_y, baseline_residual = baseline_rmsnorm_residual(x_baseline, baseline_rmsnorm) baseline_loss = baseline_y.sum() + baseline_residual.sum() @@ -28,7 +32,9 @@ def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): fused_loss.backward() # Use tolerances appropriate for dtype (pattern from other tests) - tols = dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + tols = ( + dict(rtol=1e-6, atol=1e-6) if input_dtype is torch.float32 else dict(rtol=2e-2, atol=1e-2) + ) assert fused_y.dtype == baseline_y.dtype assert torch.allclose(fused_y, baseline_y, **tols) From 7dac7849a89fdfc46676af8927857434cb07405b Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 12:15:47 +0100 Subject: [PATCH 12/30] add check tuple has len 2 to pre_mlp_layernorm --- megatron/core/transformer/transformer_block.py | 1 - megatron/core/transformer/transformer_layer.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index a5a80cf8943..81ab34552b0 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -28,7 +28,6 @@ get_transformer_layer_offset, ) from megatron.core.transformer.utils import sharded_state_dict_default -from megatron.core.typed_torch import not_none from megatron.core.utils import ( WrappedTensor, deprecate_inference_params, diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 381db1bbe0b..7da4b15cba3 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1175,7 +1175,13 @@ def _te_cuda_graph_replay(self, *args, **kwargs): return residual, None, None, None hidden_states = apply_module(self.pre_mlp_layernorm)(residual) if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " + f"2 elements (output, residual), but got {len(hidden_states)}" + ) hidden_states, residual = hidden_states + shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) From 678b8e973c3ea6747a5fe3b89261134405d070f9 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 12 Feb 2026 12:18:34 +0100 Subject: [PATCH 13/30] fix formatting --- megatron/core/transformer/attention.py | 8 ++++++-- megatron/core/transformer/transformer_layer.py | 7 ++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 28e3dde01c4..653f98a0aa2 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,7 +60,9 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import _flash_attn_forward + from flash_attn_3.flash_attn_interface import ( + _flash_attn_forward, + ) from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -71,7 +73,9 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 7da4b15cba3..6d4b16dde4a 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -1177,11 +1177,12 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if isinstance(hidden_states, tuple): if len(hidden_states) != 2: raise ValueError( - f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " - f"2 elements (output, residual), but got {len(hidden_states)}" + f"When the output of pre_mlp_layernorm is a tuple,\ + it is expected to have 2 elements (output, residual),\ + but got {len(hidden_states)}" ) hidden_states, residual = hidden_states - + shared_expert_output = self.mlp.shared_experts_compute(hidden_states) probs, routing_map = self.mlp.route(hidden_states) hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) From f366acb1e23bb33cd8fc6affb9c6e9cb147c1ed9 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Fri, 13 Feb 2026 11:27:08 +0100 Subject: [PATCH 14/30] Add checks for tuple length in MultiTokenPredictionLayer and Transformer classes --- .../transformer/multi_token_prediction.py | 18 +++++++++ .../core/transformer/transformer_block.py | 9 ++++- .../core/transformer/transformer_layer.py | 40 ++++++++++++------- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index f7928f43140..afbeb9816a1 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -898,10 +898,22 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T """ decoder_input = apply_module(self.enorm)(decoder_input) if isinstance(decoder_input, tuple): + if len(decoder_input) != 2: + raise ValueError( + f"When the output of enorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(decoder_input)}" + ) decoder_input = decoder_input[0] decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True) hidden_states = apply_module(self.hnorm)(hidden_states) if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of hnorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) hidden_states = hidden_states[0] hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states @@ -996,6 +1008,12 @@ def _postprocess(self, hidden_states: torch.Tensor): # Layer norm before shared head layer. hidden_states = apply_module(self.final_layernorm)(hidden_states) if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of final_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 81ab34552b0..d9d52013320 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -855,9 +855,14 @@ def forward( # Final layer norm. if self.final_layernorm is not None: hidden_states = self.final_layernorm(hidden_states) - # Handle fused residual normalization (returns tuple of (output, residual)) - # For final layernorm, we only need the normalized output, not the residual if isinstance(hidden_states, tuple): + if len(hidden_states) != 2: + raise ValueError( + f"When the output of final_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" + ) + # For final layernorm, we only need the normalized output, not the residual hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 6d4b16dde4a..20281513ae8 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -746,22 +746,23 @@ def _forward_mlp( output (Tensor): Transformed hidden states of shape [s, b, h]. """ - # Residual connection. - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() - # Optional Layer norm post the cross-attention. pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) - # Handle fused residual normalization (returns tuple of (output, residual)) if isinstance(pre_mlp_layernorm_output, tuple): if len(pre_mlp_layernorm_output) != 2: raise ValueError( - f"When the output of pre_mlp_layernorm is a tuple, it is expected to have " - f"2 elements (output, residual), but got {len(pre_mlp_layernorm_output)}" + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" ) pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + # Residual connection. + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size @@ -1177,9 +1178,9 @@ def _te_cuda_graph_replay(self, *args, **kwargs): if isinstance(hidden_states, tuple): if len(hidden_states) != 2: raise ValueError( - f"When the output of pre_mlp_layernorm is a tuple,\ - it is expected to have 2 elements (output, residual),\ - but got {len(hidden_states)}" + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(hidden_states)}" ) hidden_states, residual = hidden_states @@ -1344,11 +1345,22 @@ def _forward_mlp_router(self, hidden_states, padding_mask=None): This method is isolated so it can be captured by `cudagraph_manager_router`. """ - residual = hidden_states - if self.config.fp32_residual_connection: - residual = residual.float() self.mlp.fwd_execution_map = "route" pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + if isinstance(pre_mlp_layernorm_output, tuple): + if len(pre_mlp_layernorm_output) != 2: + raise ValueError( + f"When the output of pre_mlp_layernorm is a tuple, it is " + f"expected to have 2 elements (output, residual), but " + f"got {len(pre_mlp_layernorm_output)}" + ) + pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output + else: + residual = hidden_states + + if self.config.fp32_residual_connection: + residual = residual.float() + router_outputs = self.mlp( pre_mlp_layernorm_output, intermediate_tensors=(), padding_mask=padding_mask ) From bc1cf5b843267acb0d744c9dce54eaddf59508ad Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Tue, 17 Feb 2026 13:51:39 +0100 Subject: [PATCH 15/30] Revert changes to attention.py --- megatron/core/transformer/attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 653f98a0aa2..28e3dde01c4 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,9 +60,7 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import ( - _flash_attn_forward, - ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -73,9 +71,7 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_forward, - ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) From 934d02ddb6ef2636deb6431f434a43d95aea99c7 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Tue, 17 Feb 2026 18:05:22 +0100 Subject: [PATCH 16/30] remove unnecessary unpacking --- .../transformer/multi_token_prediction.py | 24 ------------------- .../core/transformer/transformer_block.py | 12 ++-------- 2 files changed, 2 insertions(+), 34 deletions(-) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index afbeb9816a1..4ad2e517cfc 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -897,24 +897,8 @@ def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.T Concatenate the tokens before sending to transformer layer. """ decoder_input = apply_module(self.enorm)(decoder_input) - if isinstance(decoder_input, tuple): - if len(decoder_input) != 2: - raise ValueError( - f"When the output of enorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(decoder_input)}" - ) - decoder_input = decoder_input[0] decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True) hidden_states = apply_module(self.hnorm)(hidden_states) - if isinstance(hidden_states, tuple): - if len(hidden_states) != 2: - raise ValueError( - f"When the output of hnorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(hidden_states)}" - ) - hidden_states = hidden_states[0] hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) # At the (k - 1)-th MTP module, concatenates the i-th token's hidden_states # and the (i + K)-th token's embedding, and combine them with linear projection. @@ -1007,14 +991,6 @@ def _postprocess(self, hidden_states: torch.Tensor): # Layer norm before shared head layer. hidden_states = apply_module(self.final_layernorm)(hidden_states) - if isinstance(hidden_states, tuple): - if len(hidden_states) != 2: - raise ValueError( - f"When the output of final_layernorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(hidden_states)}" - ) - hidden_states = hidden_states[0] # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index d9d52013320..ea3f33d3ee5 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -28,6 +28,7 @@ get_transformer_layer_offset, ) from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.typed_torch import apply_module, not_none from megatron.core.utils import ( WrappedTensor, deprecate_inference_params, @@ -854,16 +855,7 @@ def forward( # Final layer norm. if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - if isinstance(hidden_states, tuple): - if len(hidden_states) != 2: - raise ValueError( - f"When the output of final_layernorm is a tuple, it is " - f"expected to have 2 elements (output, residual), but " - f"got {len(hidden_states)}" - ) - # For final layernorm, we only need the normalized output, not the residual - hidden_states = hidden_states[0] + hidden_states = apply_module(self.final_layernorm)(cast(Tensor, hidden_states)) # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. From 2d9ba48b4f057bf329aa4b4961b6f40a1dfb5bf6 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 14:09:46 +0100 Subject: [PATCH 17/30] guard has_residual behind TENorm check --- .../core/transformer/transformer_layer.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 20281513ae8..bd0f83000c3 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -302,13 +302,23 @@ def __init__( self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout self.is_mtp_layer = is_mtp_layer + + # import here to avoid circular import + from megatron.core.extensions.transformer_engine import TENorm + def _build_layernorm(builder: LayerNormBuilder, has_residual_connection: bool): + norm_kwargs: Dict[str, Any] = { + "config": self.config, + "hidden_size": self.config.hidden_size, + "eps": self.config.layernorm_epsilon, + } + if has_residual_connection and builder is TENorm: + norm_kwargs["has_residual"] = True + return builder(**norm_kwargs) + # [Module 1: Input Layernorm] Optional Layernorm on the input data # TODO: add pytorch only layernorm - self.input_layernorm = submodules.input_layernorm( - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - has_residual=True, # Followed by self-attention + residual add + self.input_layernorm = _build_layernorm( + submodules.input_layernorm, has_residual_connection=True ) attention_optional_kwargs = {} @@ -334,11 +344,8 @@ def __init__( self.self_attn_bda = build_module(submodules.self_attn_bda) # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm( - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - has_residual=True, # Followed by cross-attention + residual add + self.pre_cross_attn_layernorm = _build_layernorm( + submodules.pre_cross_attn_layernorm, has_residual_connection=True ) # [Module 5: CrossAttention] @@ -353,11 +360,8 @@ def __init__( self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config) # [Module 7: Pre MLP] Optional Layernorm before MLP - self.pre_mlp_layernorm = submodules.pre_mlp_layernorm( - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - has_residual=True, # Followed by MLP + residual add + self.pre_mlp_layernorm = _build_layernorm( + submodules.pre_mlp_layernorm, has_residual_connection=True ) # [Module 8: MLP block] additional_mlp_kwargs = {} From 818e0e732389da99667521b483106dbfa3b7d570 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 14:44:05 +0100 Subject: [PATCH 18/30] autoformat --- megatron/core/transformer/transformer_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index bd0f83000c3..fbb1f8e3604 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -302,9 +302,9 @@ def __init__( self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout self.is_mtp_layer = is_mtp_layer - # import here to avoid circular import from megatron.core.extensions.transformer_engine import TENorm + def _build_layernorm(builder: LayerNormBuilder, has_residual_connection: bool): norm_kwargs: Dict[str, Any] = { "config": self.config, From d1f59cc74b5b82878ebfda1b3a5f959f01953546 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 15:01:28 +0100 Subject: [PATCH 19/30] add missing copyright header --- tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index a7504a6e8d0..324c162186d 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -1,3 +1,5 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + import pytest import torch from transformer_engine.pytorch import RMSNorm From a1ad51e2815e21cddae66708f49036e3039adef0 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 19:19:42 +0100 Subject: [PATCH 20/30] remove quantize arg from test --- tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py index 324c162186d..6c03e0fa801 100644 --- a/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py +++ b/tests/unit_tests/fusions/test_rmsnorm_residual_fusion.py @@ -20,7 +20,7 @@ def test_rmsnorm_residual_fusion(input_dtype, normalized_shape): x_fused.requires_grad = True baseline_rmsnorm = RMSNorm(normalized_shape=normalized_shape, dtype=input_dtype).cuda() fused_rmsnorm = TEFusedResidualRMSNorm( - normalized_shape=normalized_shape, dtype=input_dtype, quantize=False + normalized_shape=normalized_shape, dtype=input_dtype ).cuda() # baseline From c5c9b25f6109960fb87de1e5409ed96add490f21 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 23 Feb 2026 19:30:19 +0100 Subject: [PATCH 21/30] add arg to golden_dict --- tests/unit_tests/models/test_mamba_moe_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 98c6ac63e0e..7df49e99394 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -113,6 +113,7 @@ "fp8_quantizer_factory": None, "fp8_recipe": "delayed", "fp8_wgrad": True, + "fused_residual_rmsnorm": False, "fused_single_qkv_rope": False, "gated_linear_unit": False, "glu_linear_offset": 0.0, From e89bb23f969951da184d0a27abc843ca5dc3764d Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Tue, 24 Feb 2026 11:54:08 +0100 Subject: [PATCH 22/30] compact TENorm --- .../core/extensions/transformer_engine.py | 53 ++++++++----------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bfea8f88d5e..f7fd4c481ea 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -628,46 +628,39 @@ def __new__( "Transformer Engine is not installed. " "Please install it with `pip install transformer-engine`." ) - if config.normalization == "LayerNorm": - if config.fused_residual_rmsnorm and has_residual: - raise ValueError("Fused residual RMSNorm is not supported for LayerNorm") - instance = te.pytorch.LayerNorm( - hidden_size=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **_get_extra_te_kwargs(config), + + use_fused_residual = config.fused_residual_rmsnorm and has_residual + if use_fused_residual and config.normalization != "RMSNorm": + raise ValueError( + "Fused residual is only supported " + "for RMSNorm normalization" ) + + if config.normalization == "LayerNorm": + norm_module = te.pytorch.LayerNorm elif config.normalization == "RMSNorm": assert hasattr( te.pytorch, "RMSNorm" ), "Transformer-Engine >= v0.11 required to use this feature" - - extra_te_kwargs = _get_extra_te_kwargs(config) - - if config.fused_residual_rmsnorm and has_residual: - # Use fused residual variant + if use_fused_residual: assert ( TEFusedResidualRMSNorm is not None ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0" - instance = TEFusedResidualRMSNorm( - normalized_shape=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **extra_te_kwargs, - ) + norm_module = TEFusedResidualRMSNorm else: - # Standard RMSNorm without fusion - instance = te.pytorch.RMSNorm( - normalized_shape=hidden_size, - eps=eps, - sequence_parallel=config.sequence_parallel, - zero_centered_gamma=config.layernorm_zero_centered_gamma, - **extra_te_kwargs, - ) + norm_module = te.pytorch.RMSNorm else: - raise Exception("Only LayerNorm and RMSNorm are curently supported") + raise Exception( + "Only LayerNorm and RMSNorm are currently supported" + ) + + instance = norm_module( + normalized_shape=hidden_size, + eps=eps, + sequence_parallel=config.sequence_parallel, + zero_centered_gamma=config.layernorm_zero_centered_gamma, + **_get_extra_te_kwargs(config), + ) return cast(LayerNormInterface, instance) From 650355f8d912cc70c4b2c21cf257ffdbd3e916aa Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Tue, 24 Feb 2026 20:38:35 +0100 Subject: [PATCH 23/30] format --- megatron/core/extensions/transformer_engine.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index f7fd4c481ea..aae3c80630e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -631,11 +631,8 @@ def __new__( use_fused_residual = config.fused_residual_rmsnorm and has_residual if use_fused_residual and config.normalization != "RMSNorm": - raise ValueError( - "Fused residual is only supported " - "for RMSNorm normalization" - ) - + raise ValueError("Fused residual is only supported " "for RMSNorm normalization") + if config.normalization == "LayerNorm": norm_module = te.pytorch.LayerNorm elif config.normalization == "RMSNorm": @@ -650,10 +647,8 @@ def __new__( else: norm_module = te.pytorch.RMSNorm else: - raise Exception( - "Only LayerNorm and RMSNorm are currently supported" - ) - + raise Exception("Only LayerNorm and RMSNorm are currently supported") + instance = norm_module( normalized_shape=hidden_size, eps=eps, From 18666186bb7e50173e08371fedf8787ad61e5dfc Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 2 Mar 2026 21:36:35 +0100 Subject: [PATCH 24/30] leverage spec to simplify build of has_residual layernorm --- .../transformer_engine_spec_provider.py | 4 +-- megatron/core/models/backends.py | 6 ++-- megatron/core/models/gpt/gpt_layer_specs.py | 12 +++---- .../core/transformer/transformer_layer.py | 31 +++++++------------ 4 files changed, 23 insertions(+), 30 deletions(-) diff --git a/megatron/core/extensions/transformer_engine_spec_provider.py b/megatron/core/extensions/transformer_engine_spec_provider.py index a445eacdfe0..872acea9601 100644 --- a/megatron/core/extensions/transformer_engine_spec_provider.py +++ b/megatron/core/extensions/transformer_engine_spec_provider.py @@ -52,14 +52,14 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: """Which module for sequential layernorm and linear""" return TELayerNormColumnParallelLinear - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: """Which module to use for layer norm""" if for_qk and not is_te_min_version("1.9.0"): # TENorm significantly harms convergence when used # for QKLayerNorm if TE Version < 1.9; # we instead use the Apex implementation. return FusedLayerNorm - return TENorm + return lambda *args, **kwargs: TENorm(*args, has_residual=has_residual, **kwargs) def core_attention(self) -> type: """Which module to use for attention""" diff --git a/megatron/core/models/backends.py b/megatron/core/models/backends.py index ebb979772f0..ad3fde44ceb 100644 --- a/megatron/core/models/backends.py +++ b/megatron/core/models/backends.py @@ -73,7 +73,7 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: ... @abstractmethod - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: """Which module for layernorm""" ... @@ -114,7 +114,7 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: """Which module for sequential layernorm and linear""" return None - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: """Which module to use for layer norm""" if rms_norm: # Matching get_gpt_layer_local_spec. @@ -170,7 +170,7 @@ def column_parallel_layer_norm_linear(self) -> type[InferenceLayerNormColumnPara """Which module for sequential layernorm and linear""" return InferenceLayerNormColumnParallelLinear - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder: + def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: """Which module to use for layer norm""" if for_qk and not is_te_min_version("1.9.0"): # TENorm significantly harms convergence when used diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index aae2d5f3e81..67b8aef2602 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -111,7 +111,7 @@ def get_gpt_layer_with_inference_submodules( else backend.column_parallel_linear() ) return TransformerLayerSubmodules( - input_layernorm=backend.layer_norm(), + input_layernorm=backend.layer_norm(has_residual=True), self_attention=ModuleSpec( module=MLASelfAttention, params={"attn_mask_type": AttnMaskType.causal}, @@ -249,7 +249,7 @@ def get_gpt_layer_with_transformer_engine_submodules( else backend.column_parallel_linear() ) return TransformerLayerSubmodules( - input_layernorm=backend.layer_norm(), + input_layernorm=backend.layer_norm(has_residual=True), self_attention=ModuleSpec( module=MLASelfAttention, params={"attn_mask_type": AttnMaskType.causal}, @@ -266,7 +266,7 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, ) @@ -289,7 +289,7 @@ def get_gpt_layer_with_transformer_engine_submodules( ), ), self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, sharded_state_dict_keys_map={ @@ -353,10 +353,10 @@ def get_gpt_layer_local_submodules( backend = LocalSpecProvider() # Adjust for RMS norm. if normalization == "RMSNorm": - layer_norm = backend.layer_norm(rms_norm=True, for_qk=False) + layer_norm = backend.layer_norm(rms_norm=True, for_qk=False, has_residual=True) qk_norm = backend.layer_norm(rms_norm=True, for_qk=True) else: - layer_norm = backend.layer_norm(rms_norm=False, for_qk=False) + layer_norm = backend.layer_norm(rms_norm=False, for_qk=False, has_residual=True) qk_norm = backend.layer_norm(rms_norm=False, for_qk=True) if fp8 is not None: diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index fbb1f8e3604..9d14f1fecef 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -302,23 +302,12 @@ def __init__( self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout self.is_mtp_layer = is_mtp_layer - # import here to avoid circular import - from megatron.core.extensions.transformer_engine import TENorm - - def _build_layernorm(builder: LayerNormBuilder, has_residual_connection: bool): - norm_kwargs: Dict[str, Any] = { - "config": self.config, - "hidden_size": self.config.hidden_size, - "eps": self.config.layernorm_epsilon, - } - if has_residual_connection and builder is TENorm: - norm_kwargs["has_residual"] = True - return builder(**norm_kwargs) - # [Module 1: Input Layernorm] Optional Layernorm on the input data # TODO: add pytorch only layernorm - self.input_layernorm = _build_layernorm( - submodules.input_layernorm, has_residual_connection=True + self.input_layernorm = submodules.input_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, ) attention_optional_kwargs = {} @@ -344,8 +333,10 @@ def _build_layernorm(builder: LayerNormBuilder, has_residual_connection: bool): self.self_attn_bda = build_module(submodules.self_attn_bda) # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = _build_layernorm( - submodules.pre_cross_attn_layernorm, has_residual_connection=True + self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, ) # [Module 5: CrossAttention] @@ -360,8 +351,10 @@ def _build_layernorm(builder: LayerNormBuilder, has_residual_connection: bool): self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config) # [Module 7: Pre MLP] Optional Layernorm before MLP - self.pre_mlp_layernorm = _build_layernorm( - submodules.pre_mlp_layernorm, has_residual_connection=True + self.pre_mlp_layernorm = submodules.pre_mlp_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, ) # [Module 8: MLP block] additional_mlp_kwargs = {} From bc0cfe79b17b9e44543afefa0c3244cb57e11f30 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 2 Mar 2026 21:52:00 +0100 Subject: [PATCH 25/30] format --- .../extensions/transformer_engine_spec_provider.py | 4 +++- megatron/core/models/backends.py | 12 +++++++++--- megatron/core/transformer/transformer_layer.py | 8 ++++---- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/megatron/core/extensions/transformer_engine_spec_provider.py b/megatron/core/extensions/transformer_engine_spec_provider.py index 872acea9601..740116a7115 100644 --- a/megatron/core/extensions/transformer_engine_spec_provider.py +++ b/megatron/core/extensions/transformer_engine_spec_provider.py @@ -52,7 +52,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: """Which module for sequential layernorm and linear""" return TELayerNormColumnParallelLinear - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module to use for layer norm""" if for_qk and not is_te_min_version("1.9.0"): # TENorm significantly harms convergence when used diff --git a/megatron/core/models/backends.py b/megatron/core/models/backends.py index ad3fde44ceb..3a57742e459 100644 --- a/megatron/core/models/backends.py +++ b/megatron/core/models/backends.py @@ -73,7 +73,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: ... @abstractmethod - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module for layernorm""" ... @@ -114,7 +116,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]: """Which module for sequential layernorm and linear""" return None - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module to use for layer norm""" if rms_norm: # Matching get_gpt_layer_local_spec. @@ -170,7 +174,9 @@ def column_parallel_layer_norm_linear(self) -> type[InferenceLayerNormColumnPara """Which module for sequential layernorm and linear""" return InferenceLayerNormColumnParallelLinear - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False) -> LayerNormBuilder: + def layer_norm( + self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False + ) -> LayerNormBuilder: """Which module to use for layer norm""" if for_qk and not is_te_min_version("1.9.0"): # TENorm significantly harms convergence when used diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 9d14f1fecef..e435409c16d 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -594,10 +594,10 @@ def _forward_attention( input_layernorm_output, residual = input_layernorm_output else: residual = hidden_states - + if self.config.fp32_residual_connection: residual = residual.float() - + using_fused_tp_inference_kernel = (not self.training) and ( self.config.inference_fuse_tp_communication ) @@ -1354,10 +1354,10 @@ def _forward_mlp_router(self, hidden_states, padding_mask=None): pre_mlp_layernorm_output, residual = pre_mlp_layernorm_output else: residual = hidden_states - + if self.config.fp32_residual_connection: residual = residual.float() - + router_outputs = self.mlp( pre_mlp_layernorm_output, intermediate_tensors=(), padding_mask=padding_mask ) From fac5e6adb08fdc8e6483f49ba3c637768bc09be9 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Mon, 2 Mar 2026 22:25:03 +0100 Subject: [PATCH 26/30] fix rebase --- megatron/core/transformer/transformer_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index ea3f33d3ee5..703f6c266f1 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -2,7 +2,7 @@ import logging from contextlib import nullcontext from dataclasses import dataclass -from typing import List, Optional, Set, Union +from typing import List, Optional, Set, Union, cast import torch from torch import Tensor From 4766110168a40155097ab50c71bfa2d9427786fb Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Tue, 3 Mar 2026 10:19:18 +0100 Subject: [PATCH 27/30] fix issue with build_module used with layernorm --- .../extensions/transformer_engine_spec_provider.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/megatron/core/extensions/transformer_engine_spec_provider.py b/megatron/core/extensions/transformer_engine_spec_provider.py index 740116a7115..051517d2ed0 100644 --- a/megatron/core/extensions/transformer_engine_spec_provider.py +++ b/megatron/core/extensions/transformer_engine_spec_provider.py @@ -29,6 +29,13 @@ from megatron.core.utils import get_te_version, is_te_min_version +class _TENormWithResidual: + """Class adapter for TENorm with residual fusion enabled.""" + + def __new__(cls, *args, **kwargs): + return TENorm(*args, has_residual=True, **kwargs) + + class TESpecProvider(BackendSpecProvider): """A protocol for providing the submodules used in Spec building.""" @@ -61,7 +68,8 @@ def layer_norm( # for QKLayerNorm if TE Version < 1.9; # we instead use the Apex implementation. return FusedLayerNorm - return lambda *args, **kwargs: TENorm(*args, has_residual=has_residual, **kwargs) + # Keep returning a class so this path stays aligned with build_module's class handling. + return _TENormWithResidual if has_residual else TENorm def core_attention(self) -> type: """Which module to use for attention""" From 3fcb87800275e519220050b6f38035f4c9290401 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Tue, 3 Mar 2026 10:59:09 +0100 Subject: [PATCH 28/30] fix docs error --- megatron/core/extensions/transformer_engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index aae3c80630e..61848d0f379 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -603,8 +603,6 @@ class TENorm: """A conditional wrapper to initialize an instance of Transformer-Engine's `LayerNorm` or `RMSNorm` based on input. - Residual Fusion Design: - ---------------------- Residual fusion is a two-level opt-in mechanism: 1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature) From 71bd0994998864e14871e2bfe72b8dff8202cfb4 Mon Sep 17 00:00:00 2001 From: Eric Harper Date: Thu, 5 Mar 2026 15:32:43 -0700 Subject: [PATCH 29/30] Update megatron/core/transformer/transformer_config.py Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> --- megatron/core/transformer/transformer_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 8ef3cff98c7..69d95461656 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -446,7 +446,7 @@ class TransformerConfig(ModelParallelConfig): fused_single_qkv_rope: bool = False """If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads.""" - fused_residual_rmsnorm: bool = False + """If True, fuses residual connection and RMSNorm when TE is used.""" """If True, uses fuses residual connection and RMSNorm when TE is used.""" #################### From bd7395a7c7b791cab0a0e1b006c8ceac8bb5ce11 Mon Sep 17 00:00:00 2001 From: Eric Harper Date: Thu, 5 Mar 2026 22:45:07 -0700 Subject: [PATCH 30/30] Update megatron/core/transformer/transformer_config.py Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> --- megatron/core/transformer/transformer_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 02f0e12bdd3..fd4025de9f7 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -446,8 +446,8 @@ class TransformerConfig(ModelParallelConfig): fused_single_qkv_rope: bool = False """If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads.""" - """If True, fuses residual connection and RMSNorm when TE is used.""" - """If True, uses fuses residual connection and RMSNorm when TE is used.""" + fused_residual_rmsnorm: bool = False + """If True, fuses residual connection and RMSNorm backward pass when TE is used.""" #################### # activation recomputation