Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
2593e53
add fused te fused layernorm
CarlosGomes98 Jan 26, 2026
d74e006
revert changes to normalization parameter, add fusion flag instead
CarlosGomes98 Feb 12, 2026
6c1364f
Refactor TEFusedResidualRMSNorm properly wrapping it for compatibilit…
CarlosGomes98 Feb 12, 2026
57596cf
add more spots where tuple outputs break mcore
CarlosGomes98 Feb 12, 2026
ba5c4a5
remove excessive comments
CarlosGomes98 Feb 12, 2026
20393df
add quantization
CarlosGomes98 Feb 12, 2026
5d1c460
add rmsnorm residual fusion test
CarlosGomes98 Feb 12, 2026
c58a797
fix tests
CarlosGomes98 Feb 12, 2026
5d2bddc
dont use residual_add when not necessary
CarlosGomes98 Feb 12, 2026
6779413
Remove quantization for now
CarlosGomes98 Feb 12, 2026
d490afe
formatting changes
CarlosGomes98 Feb 12, 2026
7dac784
add check tuple has len 2 to pre_mlp_layernorm
CarlosGomes98 Feb 12, 2026
678b8e9
fix formatting
CarlosGomes98 Feb 12, 2026
f366acb
Add checks for tuple length in MultiTokenPredictionLayer and Transfor…
CarlosGomes98 Feb 13, 2026
bc1cf5b
Revert changes to attention.py
CarlosGomes98 Feb 17, 2026
934d02d
remove unnecessary unpacking
CarlosGomes98 Feb 17, 2026
2d9ba48
guard has_residual behind TENorm check
CarlosGomes98 Feb 23, 2026
818e0e7
autoformat
CarlosGomes98 Feb 23, 2026
d1f59cc
add missing copyright header
CarlosGomes98 Feb 23, 2026
a1ad51e
remove quantize arg from test
CarlosGomes98 Feb 23, 2026
c5c9b25
add arg to golden_dict
CarlosGomes98 Feb 23, 2026
e89bb23
compact TENorm
CarlosGomes98 Feb 24, 2026
650355f
format
CarlosGomes98 Feb 24, 2026
1866618
leverage spec to simplify build of has_residual layernorm
CarlosGomes98 Mar 2, 2026
bc0cfe7
format
CarlosGomes98 Mar 2, 2026
fac5e6a
fix rebase
CarlosGomes98 Mar 2, 2026
4766110
fix issue with build_module used with layernorm
CarlosGomes98 Mar 3, 2026
3fcb878
fix docs error
CarlosGomes98 Mar 3, 2026
71bd099
Update megatron/core/transformer/transformer_config.py
ericharper Mar 5, 2026
c5325b9
Merge branch 'main' into cgomes/ds_fuse_dLN_add
ericharper Mar 6, 2026
bd7395a
Update megatron/core/transformer/transformer_config.py
ericharper Mar 6, 2026
3d5501a
Merge branch 'main' into cgomes/ds_fuse_dLN_add
Phlip79 Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 204 additions & 18 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,41 +433,227 @@ def __new__(cls, config: TransformerConfig):
TEActivationOp = None


if HAVE_TE and is_te_min_version("1.13.0"):

class TEFusedResidualRMSNorm(te.pytorch.RMSNorm):
"""
RMSNorm with fused residual output for Megatron Core.

Inherits from te.pytorch.RMSNorm to maintain all parameter management,
checkpoint compatibility, and Megatron-specific features. Creates a fused
implementation using TE's ops API that shares the base class parameters.

The fused implementation uses:
- MakeExtraOutput: Forks the residual connection
- RMSNorm: Normalizes the main path

Forward pass returns: (normalized_output, residual)
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Fused implementation (stored in tuple to avoid submodule registration)
self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None

def _make_fused_impl(self) -> te.pytorch.ops.Sequential:
"""
Construct fused ops pipeline that shares parameters with base RMSNorm.

Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares
the weight parameter with self.weight from the base class.
"""

fused_impl = te.pytorch.ops.Sequential()

# Op 1: MakeExtraOutput - forks the residual
fused_impl.append(te.pytorch.ops.MakeExtraOutput())

# Op 2: RMSNorm - shares weight parameter with self
kwargs = {
"eps": self.eps,
"device": "meta", # Already initialized
"dtype": self.weight.dtype,
"zero_centered_gamma": self.zero_centered_gamma,
}

# Add sm_margin if available (TE 2.5+)
if hasattr(self, '_sm_margins'):
kwargs["sm_margin"] = self._sm_margins

rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs)

rmsnorm_op.weight = self.weight

fused_impl.append(rmsnorm_op)

self._register_hooks_on_fused_impl(fused_impl)

Comment thread
ericharper marked this conversation as resolved.
return fused_impl

def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None:

forward_pre_hooks = []
forward_post_hooks = []
backward_pre_hooks = []
backward_post_hooks = []

for submodule in self.modules():
for hook in submodule._forward_pre_hooks.values():
forward_pre_hooks.append((submodule, hook))
for hook in submodule._forward_hooks.values():
forward_post_hooks.append((submodule, hook))
for hook in submodule._backward_pre_hooks.values():
backward_pre_hooks.append((submodule, hook))
for hook in submodule._backward_hooks.values():
backward_post_hooks.append((submodule, hook))

# Pre-forward hooks
# Note: DDP pre-forward hooks are safe since they do not
# interact with input tensor.
if forward_pre_hooks:
from megatron.core.distributed import distributed_data_parallel

if any(
inspect.getmodule(hook) != distributed_data_parallel
for _, hook in forward_pre_hooks
):
warnings.warn(
"TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. "
"TEFusedResidualRMSNorm module does not expose intermediate tensors, "
"so the hook may have incorrect behavior if it attempts to "
"access the input tensor."
)

def forward_pre_hook(module, *_) -> None:
for submodule, hook in forward_pre_hooks:
# Assume that hook does not interact with input
ret = hook(submodule, None)
if ret is not None:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not expose "
"intermediate tensors, but submodule has "
"pre-forward hook that modifies input tensor."
)

fused_impl.register_forward_pre_hook(forward_pre_hook)

# Post-forward hooks
if forward_post_hooks:
warnings.warn(
"TEFusedResidualRMSNorm module has a submodule with a post-forward hook. "
"TEFusedResidualRMSNorm module does not expose intermediate tensors, "
"so the hook may have incorrect behavior if it attempts to "
"access the input or output tensors."
)

def forward_post_hook(module, *_) -> None:
for submodule, hook in forward_post_hooks:
# Assume that hook does not interact with input or output
ret = hook(submodule, None, None)
if ret is not None:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not expose "
"intermediate tensors, but submodule has "
"post-forward hook that modifies output tensor."
)

fused_impl.register_forward_hook(forward_post_hook)

# Backward hooks
if backward_pre_hooks:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not support "
"submodules with pre-backward hooks"
)
if backward_post_hooks:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not support "
"submodules with post-backward hooks"
)

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass with fused residual output.

Args:
hidden_states: Input tensor [s, b, h]

Returns:
Tuple of (normalized_output, residual), both [s, b, h]

Note:
Sequential.forward() automatically returns (output, extra_outputs...)
when MakeExtraOutput is present, so we don't need manual unpacking.
"""

# Construct fused impl lazily on first forward
# (in case parameters are modified after __init__)
if self._fused_impl is None:
self._fused_impl = (self._make_fused_impl(),)

# Apply fused implementation
# Sequential returns (normalized_output, residual) automatically
return self._fused_impl[0](hidden_states)

else:
TEFusedResidualRMSNorm = None # type: ignore[assignment, misc]


class TENorm:
"""A conditional wrapper to initialize an instance of
Transformer-Engine's `LayerNorm` or `RMSNorm` based on input."""
Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.

Residual fusion is a two-level opt-in mechanism:

1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature)
2. Local intent: has_residual=True must be passed at build site (declares this specific
norm is followed by a residual connection)

Fusion only happens when BOTH conditions are met.

"""

# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(
cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5
) -> LayerNormInterface:
cls,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
has_residual: bool = False,
):
if not HAVE_TE:
raise ImportError(
"Transformer Engine is not installed. "
"Please install it with `pip install transformer-engine`."
)

use_fused_residual = config.fused_residual_rmsnorm and has_residual
if use_fused_residual and config.normalization != "RMSNorm":
raise ValueError("Fused residual is only supported " "for RMSNorm normalization")

if config.normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
norm_module = te.pytorch.LayerNorm
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
if use_fused_residual:
assert (
TEFusedResidualRMSNorm is not None
), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0"
norm_module = TEFusedResidualRMSNorm
else:
norm_module = te.pytorch.RMSNorm
else:
raise Exception("Only LayerNorm and RMSNorm are curently supported")
raise Exception("Only LayerNorm and RMSNorm are currently supported")

instance = norm_module(
normalized_shape=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)

return cast(LayerNormInterface, instance)

Expand Down
14 changes: 12 additions & 2 deletions megatron/core/extensions/transformer_engine_spec_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
from megatron.core.utils import get_te_version, is_te_min_version


class _TENormWithResidual:
"""Class adapter for TENorm with residual fusion enabled."""

def __new__(cls, *args, **kwargs):
return TENorm(*args, has_residual=True, **kwargs)


class TESpecProvider(BackendSpecProvider):
"""A protocol for providing the submodules used in Spec building."""

Expand All @@ -51,14 +58,17 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]:
"""Which module for sequential layernorm and linear"""
return TELayerNormColumnParallelLinear

def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
def layer_norm(
self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
) -> LayerNormBuilder:
"""Which module to use for layer norm"""
if for_qk and not is_te_min_version("1.9.0"):
# TENorm significantly harms convergence when used
# for QKLayerNorm if TE Version < 1.9;
# we instead use the Apex implementation.
return FusedLayerNorm
return TENorm
# Keep returning a class so this path stays aligned with build_module's class handling.
return _TENormWithResidual if has_residual else TENorm

def core_attention(self) -> type:
"""Which module to use for attention"""
Expand Down
12 changes: 9 additions & 3 deletions megatron/core/models/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]:
...

@abstractmethod
def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
def layer_norm(
self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
) -> LayerNormBuilder:
"""Which module for layernorm"""
...

Expand Down Expand Up @@ -113,7 +115,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]:
"""Which module for sequential layernorm and linear"""
return None

def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
def layer_norm(
self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
) -> LayerNormBuilder:
"""Which module to use for layer norm"""
if rms_norm:
# Matching get_gpt_layer_local_spec.
Expand Down Expand Up @@ -162,7 +166,9 @@ def column_parallel_layer_norm_linear(self) -> type[InferenceLayerNormColumnPara
"""Which module for sequential layernorm and linear"""
return InferenceLayerNormColumnParallelLinear

def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
def layer_norm(
self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
) -> LayerNormBuilder:
"""Which module to use for layer norm"""
if for_qk and not is_te_min_version("1.9.0"):
# TENorm significantly harms convergence when used
Expand Down
12 changes: 6 additions & 6 deletions megatron/core/models/gpt/gpt_layer_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_gpt_layer_with_inference_submodules(
else backend.column_parallel_linear()
)
return TransformerLayerSubmodules(
input_layernorm=backend.layer_norm(),
input_layernorm=backend.layer_norm(has_residual=True),
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
Expand Down Expand Up @@ -244,7 +244,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
else backend.column_parallel_linear()
)
return TransformerLayerSubmodules(
input_layernorm=backend.layer_norm(),
input_layernorm=backend.layer_norm(has_residual=True),
self_attention=ModuleSpec(
module=MLASelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
Expand All @@ -261,7 +261,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
)
Expand All @@ -284,7 +284,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
Expand Down Expand Up @@ -345,10 +345,10 @@ def get_gpt_layer_local_submodules(
backend = LocalSpecProvider()
# Adjust for RMS norm.
if normalization == "RMSNorm":
layer_norm = backend.layer_norm(rms_norm=True, for_qk=False)
layer_norm = backend.layer_norm(rms_norm=True, for_qk=False, has_residual=True)
qk_norm = backend.layer_norm(rms_norm=True, for_qk=True)
else:
layer_norm = backend.layer_norm(rms_norm=False, for_qk=False)
layer_norm = backend.layer_norm(rms_norm=False, for_qk=False, has_residual=True)
qk_norm = backend.layer_norm(rms_norm=False, for_qk=True)

if fp8 is not None:
Expand Down
9 changes: 9 additions & 0 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ class TransformerConfig(ModelParallelConfig):
fused_single_qkv_rope: bool = False
"""If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads."""

fused_residual_rmsnorm: bool = False
"""If True, fuses residual connection and RMSNorm backward pass when TE is used."""

####################
# activation recomputation
####################
Expand Down Expand Up @@ -1635,6 +1638,12 @@ def __post_init__(self):
"to True and use_te_activation_func to False."
)

if self.fused_residual_rmsnorm:
if self.normalization != "RMSNorm":
raise ValueError(
"fused_residual_rmsnorm is only supported when normalization is RMSNorm."
)

if self.use_te_activation_func:
if self.activation_func not in (F.gelu, F.silu, F.relu):
raise ValueError(
Expand Down
Loading
Loading