feat: refine rms norm ops#343
Open
RuibinCheung wants to merge 2 commits into
Open
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR migrates Primus-Turbo’s RMSNorm from a legacy C++/CUDA extension to Triton kernels implemented in Python, adds a fused rmsnorm_residual variant to combine residual-add + RMSNorm in one launch, and updates tests to cover more dtypes/shapes.
Changes:
- Add Triton RMSNorm forward/backward kernels (single-row + multi-row) and a fused residual-add RMSNorm variant.
- Introduce Python launchers and new autograd-backed public ops:
rmsnormandrmsnorm_residual. - Update the
RMSNormmodule and expand/restore test coverage for fp16/bf16 and the new residual API.
Reviewed changes
Copilot reviewed 7 out of 9 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
primus_turbo/triton/normalization/rmsnorm_kernel.py |
New Triton kernels for RMSNorm (standard + residual), including multi-row variants. |
primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py |
New Python launchers/config selection for Triton RMSNorm kernels. |
primus_turbo/pytorch/ops/normalization.py |
New public ops (rmsnorm, rmsnorm_residual) via custom autograd Functions. |
primus_turbo/pytorch/modules/normalization.py |
Switch module implementation to call the Triton-backed RMSNorm op. |
tests/pytorch/ops/test_normalization.py |
Re-enable fp16/bf16 testing, adjust shapes, add residual op tests. |
csrc/pytorch/normalization/normalization.cpp |
Removed legacy C++/CUDA RMSNorm implementation. |
csrc/pytorch/normalization/normalization_meta.cpp |
Removed legacy meta implementation for RMSNorm. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
30
to
+36
| def forward(ctx, x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6): | ||
| y = torch.ops.primus_turbo_cpp_extension.rmsnorm_fwd(x, gamma, eps) | ||
| assert x.is_cuda and gamma.is_cuda, "rmsnorm: x and gamma must be CUDA tensors" | ||
| orig_shape = x.shape | ||
| H = gamma.shape[0] | ||
| assert ( | ||
| orig_shape[-1] == H | ||
| ), f"rmsnorm: last dim of x ({orig_shape[-1]}) must equal gamma.shape[0] ({H})" |
| ), f"rmsnorm_residual: shape mismatch {tuple(x.shape)} vs {tuple(residual.shape)}" | ||
| orig_shape = x.shape | ||
| H = gamma.shape[0] | ||
| assert orig_shape[-1] == H |
Comment on lines
+18
to
+23
| from primus_turbo.pytorch.kernels.normalization.rmsnorm_impl import ( | ||
| rmsnorm_bwd_impl, | ||
| rmsnorm_bwd_residual_impl, | ||
| rmsnorm_fwd_impl, | ||
| rmsnorm_fwd_residual_impl, | ||
| ) |
Comment on lines
+106
to
+112
| @staticmethod | ||
| def backward(ctx, grad_y: torch.Tensor, grad_xpr: torch.Tensor): | ||
| x_plus_r, gamma, rstd = ctx.saved_tensors | ||
| dx, dg = rmsnorm_bwd_residual_impl( | ||
| grad_y, | ||
| grad_xpr, | ||
| x_plus_r, |
| if self.elementwise_affine: | ||
| torch.nn.init.ones_(self.weight) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
Comment on lines
+70
to
+73
| # Backward — only flow gradient through y so dxpr = 0. | ||
| grad_out = torch.randn_like(y) | ||
| y.backward(grad_out) | ||
| y_ref.backward(grad_out) |
Contributor
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 9 out of 11 changed files in this pull request and generated 8 comments.
Comments suppressed due to low confidence (3)
primus_turbo/pytorch/ops/normalization.py:92
rmsnorm_residualalso relies on Pythonassertfor required checks (CUDA tensors, shape match, last-dim match). For a public op, these validations should be done with explicit exceptions so they are not skipped underpython -O, and so error types/messages are consistent.
def forward(ctx, x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6):
assert (
x.is_cuda and residual.is_cuda and gamma.is_cuda
), "rmsnorm_residual: x, residual and gamma must be CUDA tensors"
assert (
x.shape == residual.shape
), f"rmsnorm_residual: shape mismatch {tuple(x.shape)} vs {tuple(residual.shape)}"
orig_shape = x.shape
H = gamma.shape[0]
assert orig_shape[-1] == H
primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py:275
- Same as the standard backward: when
ROWS == 1, the residual backward allocates a(B, H)fp32dg_partialbuffer, which can be very large and risk OOM for bigB/H. Consider using an in-kernel reduction strategy to avoid materializing the full buffer.
dx = torch.empty_like(x_plus_r)
if ROWS == 1:
dg_partial = torch.empty(B, H, device=x_plus_r.device, dtype=torch.float32)
rmsnorm_bwd_residual_kernel[(B,)](
csrc/pytorch/bindings_pytorch.cpp:98
dequantize_fp8_rowwiseis registered in the CUDA impl table, but no definition is present in the codebase (only a prototype). This will break the extension build at link time. Add the missing implementation or remove the registration.
// ********* Quantization *********
m.impl("quantize_fp8_tensorwise", quantize_fp8_tensorwise);
m.impl("dequantize_fp8_tensorwise", dequantize_fp8_tensorwise);
m.impl("quantize_fp8_rowwise", quantize_fp8_rowwise);
m.impl("dequantize_fp8_rowwise", dequantize_fp8_rowwise);
Comment on lines
51
to
+52
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) | ||
| return _rmsnorm(x, self.weight, self.eps) |
32a73ba to
e8eee83
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.


Description
Migrate MLPerf optimized rmsnorm triton kernel into Primus-Turbo.
This PR refactors the RMSNorm operator: the legacy C++/CUDA implementation under
csrc/pytorch/normalization/is removed and replaced with a Triton-backed implementation living entirely in Python. The new implementation is stride-aware, picks its launch configuration based on(H, B), and additionally introduces a fusedrmsnorm_residualop that combinesx + residualwith the RMSNorm in a single kernel launch—removing the standalone elementwise add that typically precedes the norm in residual paths (e.g. transformer blocks / MoE attention).Motivation:
.contiguous()copies on the autograd hot path by making the kernels stride-aware on both batch and hidden dimensions.x + residualwithout re-reading the tensor.Fixes # (issue)
Type of change
Changes
csrc/pytorch/normalization/normalization.cppcsrc/pytorch/normalization/normalization_meta.cppprimus_turbo/triton/normalization/rmsnorm_kernel.py:rmsnorm_fwd_kernel/rmsnorm_fwd_kernel_multi_rowrmsnorm_bwd_kernel/rmsnorm_bwd_kernel_multi_rowrmsnorm_fwd_residual_kernel/rmsnorm_fwd_residual_kernel_multi_rowrmsnorm_bwd_residual_kernel/rmsnorm_bwd_residual_kernel_multi_rowreshape(-1, H)of strided tensors) without an implicit_to_copy.dgammainside each program overROWS_PER_BLOCKrows, shrinking the partial buffer from(B, H)to(num_programs, H)—important for small-H / huge-B shapes such as q_norm in MoE attention.primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py:rmsnorm_fwd_impl/rmsnorm_bwd_implrmsnorm_fwd_residual_impl/rmsnorm_bwd_residual_impl_pick_config(H, B)chooses(BLOCK_H, ROWS_PER_BLOCK, num_warps, num_stages)adaptively. Multi-row mode is enabled whenBLOCK_H ≤ 256andB ≥ 4096.primus_turbo/pytorch/ops/normalization.py:rmsnorm(x, gamma, eps=1e-6) -> ynow dispatches to the Triton launchers via_RMSNormFunction(customtorch.autograd.Function).rmsnorm_residual(x, residual, gamma, eps=1e-6) -> (y, x_plus_r)implemented as_RMSNormResidualFunction. Backward returns the same gradient for bothxandresidualbecauseaddhas Jacobian[I, I].primus_turbo/pytorch/modules/normalization.py(RMSNormmodule): switched to importing the Triton-backedrmsnormop; module API is unchanged.tests/pytorch/ops/test_normalization.py):torch.float16andtorch.bfloat16fortest_rmsnorm_ops(previously fp32-only) and expanded shape coverage (inner_shape=128added; oversized 16384 outer removed to keep the grid bounded).test_rmsnorm_residual_opscovering forward, backward, dtype coverage, and the fact thatx.grad == residual.grad.dgammauses a slightly looser tolerance for fp16/bf16 to absorb the extra reduction noise from the doubled input magnitude.Checklist: