feat: refine rms norm ops #343

Open
RuibinCheung wants to merge 2 commits into main from dev/zhangrb/refine_rmsnorm

Conversation

@RuibinCheung
Collaborator

Description

Migrate the MLPerf-optimized RMSNorm Triton kernel into Primus-Turbo.

This PR refactors the RMSNorm operator. The legacy C++/CUDA implementation under csrc/pytorch/normalization/ is removed and replaced with a Triton-backed implementation written entirely in Python. The new implementation is stride-aware and picks its launch configuration based on (H, B). It also introduces a fused rmsnorm_residual op that computes x + residual and the RMSNorm in a single kernel launch, removing the standalone elementwise add that typically precedes the norm in residual paths (e.g. transformer blocks / MoE attention).
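
For readers who want the math spelled out, the two ops can be written per row in plain Python. This is a minimal sketch of the semantics only, not the Triton kernels from this PR; the names rmsnorm_ref and rmsnorm_residual_ref are hypothetical and exist only for illustration.

```python
import math


def rmsnorm_ref(x_row, gamma, eps=1e-6):
    """Reference RMSNorm for one row: y = x * rsqrt(mean(x^2) + eps) * gamma."""
    mean_sq = sum(v * v for v in x_row) / len(x_row)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    return [v * rstd * g for v, g in zip(x_row, gamma)]


def rmsnorm_residual_ref(x_row, residual_row, gamma, eps=1e-6):
    """Reference for the fused op: returns (y, x_plus_r) like rmsnorm_residual."""
    x_plus_r = [a + b for a, b in zip(x_row, residual_row)]
    return rmsnorm_ref(x_plus_r, gamma, eps), x_plus_r


# eps=0.0 here only so that the output has a mean square of exactly 1.
y, x_plus_r = rmsnorm_residual_ref([1.0, 2.0], [2.0, 2.0], [1.0, 1.0], eps=0.0)
```

The second return value is what a following residual add would consume, which is why the fused op hands it back instead of discarding it.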

Motivation:

  • Remove the C++ build dependency for RMSNorm and unify on Triton, which is easier to tune per shape and per backend.
  • Eliminate forced .contiguous() copies on the autograd hot path by making the kernels stride-aware on both batch and hidden dimensions.
  • Provide a fused residual variant so the next residual-add can directly reuse x + residual without re-reading the tensor.
  • Restore full dtype coverage (fp32 / fp16 / bf16) in tests, which had been temporarily narrowed to fp32 only.
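
To illustrate what "stride-aware on both batch and hidden dimensions" means, here is a small pure-Python sketch of strided addressing. It assumes a flat storage buffer and a 2-D view described by an offset, a shape, and element strides, which is how torch.Tensor views are laid out; the helper name row_sum_of_squares is hypothetical and the loop stands in for what a kernel would do with pointer arithmetic.

```python
def row_sum_of_squares(storage, offset, shape, strides):
    """Sum of squares per row of a 2-D strided view over a flat buffer."""
    B, H = shape
    stride_b, stride_h = strides
    out = []
    for b in range(B):
        acc = 0.0
        for h in range(H):
            # Element (b, h) lives at offset + b * stride_b + h * stride_h.
            v = storage[offset + b * stride_b + h * stride_h]
            acc += v * v
        out.append(acc)
    return out


# A contiguous 2x4 buffer ...
storage = [1.0, 2.0, 3.0, 4.0,
           5.0, 6.0, 7.0, 8.0]
full = row_sum_of_squares(storage, 0, (2, 4), (4, 1))
# ... and a non-contiguous view of its even columns: shape (2, 2), strides (4, 2).
even_cols = row_sum_of_squares(storage, 0, (2, 2), (4, 2))
```

Because the strides are passed through, the non-contiguous view is read in place and no contiguous copy is needed first.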

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Removed the legacy C++/CUDA RMSNorm extension:
    • Deleted csrc/pytorch/normalization/normalization.cpp
    • Deleted csrc/pytorch/normalization/normalization_meta.cpp
  • Added Triton kernels under primus_turbo/triton/normalization/rmsnorm_kernel.py:
    • rmsnorm_fwd_kernel / rmsnorm_fwd_kernel_multi_row
    • rmsnorm_bwd_kernel / rmsnorm_bwd_kernel_multi_row
    • rmsnorm_fwd_residual_kernel / rmsnorm_fwd_residual_kernel_multi_row
    • rmsnorm_bwd_residual_kernel / rmsnorm_bwd_residual_kernel_multi_row
    • All kernels are stride-aware on both batch and hidden dims, so callers can pass non-contiguous views (reshape(-1, H) of strided tensors) without an implicit _to_copy.
    • Multi-row variants reduce dgamma inside each program over ROWS_PER_BLOCK rows, shrinking the partial buffer from (B, H) to (num_programs, H)—important for small-H / huge-B shapes such as q_norm in MoE attention.
  • Added Python launchers in primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py:
    • rmsnorm_fwd_impl / rmsnorm_bwd_impl
    • rmsnorm_fwd_residual_impl / rmsnorm_bwd_residual_impl
    • _pick_config(H, B) chooses (BLOCK_H, ROWS_PER_BLOCK, num_warps, num_stages) adaptively. Multi-row mode is enabled when BLOCK_H ≤ 256 and B ≥ 4096.
  • Refactored primus_turbo/pytorch/ops/normalization.py:
    • rmsnorm(x, gamma, eps=1e-6) -> y now dispatches to the Triton launchers via _RMSNormFunction (custom torch.autograd.Function).
    • New public API rmsnorm_residual(x, residual, gamma, eps=1e-6) -> (y, x_plus_r) implemented as _RMSNormResidualFunction. Backward returns the same gradient for both x and residual because add has Jacobian [I, I].
  • primus_turbo/pytorch/modules/normalization.py (RMSNorm module): switched to importing the Triton-backed rmsnorm op; module API is unchanged.
  • Tests (tests/pytorch/ops/test_normalization.py):
    • Re-enabled torch.float16 and torch.bfloat16 for test_rmsnorm_ops (previously fp32-only) and expanded shape coverage (inner_shape=128 added; oversized 16384 outer removed to keep the grid bounded).
    • Added test_rmsnorm_residual_ops covering forward, backward, dtype coverage, and the fact that x.grad == residual.grad. dgamma uses a slightly looser tolerance for fp16/bf16 to absorb the extra reduction noise from the doubled input magnitude.
    • Removed stale debug prints.
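
The launch-configuration rule described for _pick_config(H, B) can be sketched as follows. Only two things come from this PR description: the shape of the returned tuple and the multi-row condition (BLOCK_H ≤ 256 and B ≥ 4096). Everything else is an assumption made for illustration: rounding H up to a power of two for BLOCK_H, the ROWS_PER_BLOCK value of 4, and the num_warps / num_stages values. The function names are hypothetical.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def pick_config_sketch(H: int, B: int):
    """Return (BLOCK_H, ROWS_PER_BLOCK, num_warps, num_stages)."""
    block_h = next_power_of_2(H)                # assumption: BLOCK_H covers the row
    multi_row = block_h <= 256 and B >= 4096    # condition stated in the PR
    rows_per_block = 4 if multi_row else 1      # assumption: placeholder value
    num_warps = 4 if block_h <= 2048 else 8     # assumption: placeholder heuristic
    num_stages = 1                              # assumption: placeholder value
    return block_h, rows_per_block, num_warps, num_stages


def grid_size(H: int, B: int) -> int:
    """Number of programs launched: one per ROWS_PER_BLOCK rows (ceil division)."""
    _, rows_per_block, _, _ = pick_config_sketch(H, B)
    return -(-B // rows_per_block)


small_h = pick_config_sketch(H=128, B=8192)    # multi-row mode
large_h = pick_config_sketch(H=4096, B=8192)   # single-row mode
```

Under these placeholder values, grid_size(128, 8192) is 2048 rather than 8192, and since the dgamma partial buffer has one row per program, it shrinks by the same factor.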

Checklist:

  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI review requested due to automatic review settings May 18, 2026 08:00

@cursor cursor Bot left a comment

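Security review complete: found 2 high-confidence issues, both related to insufficient tensor metadata constraints in the new Triton RMSNorm kernel. No injection, authentication/authorization, SSRF, XSS, path traversal, deserialization, or dependency supply-chain risks were found.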

Sent by Cursor Automation: Find vulnerabilities

Comment thread primus_turbo/triton/normalization/rmsnorm_kernel.py
Comment thread primus_turbo/pytorch/ops/normalization.py
Contributor

Copilot AI left a comment

Pull request overview

This PR migrates Primus-Turbo’s RMSNorm from a legacy C++/CUDA extension to Triton kernels implemented in Python, adds a fused rmsnorm_residual variant to combine residual-add + RMSNorm in one launch, and updates tests to cover more dtypes/shapes.

Changes:

  • Add Triton RMSNorm forward/backward kernels (single-row + multi-row) and a fused residual-add RMSNorm variant.
  • Introduce Python launchers and new autograd-backed public ops: rmsnorm and rmsnorm_residual.
  • Update the RMSNorm module and expand/restore test coverage for fp16/bf16 and the new residual API.

Reviewed changes

Copilot reviewed 7 out of 9 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| primus_turbo/triton/normalization/rmsnorm_kernel.py | New Triton kernels for RMSNorm (standard + residual), including multi-row variants. |
| primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py | New Python launchers/config selection for Triton RMSNorm kernels. |
| primus_turbo/pytorch/ops/normalization.py | New public ops (rmsnorm, rmsnorm_residual) via custom autograd Functions. |
| primus_turbo/pytorch/modules/normalization.py | Switch module implementation to call the Triton-backed RMSNorm op. |
| tests/pytorch/ops/test_normalization.py | Re-enable fp16/bf16 testing, adjust shapes, add residual op tests. |
| csrc/pytorch/normalization/normalization.cpp | Removed legacy C++/CUDA RMSNorm implementation. |
| csrc/pytorch/normalization/normalization_meta.cpp | Removed legacy meta implementation for RMSNorm. |

Comment on lines 30 to +36
def forward(ctx, x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6):
y = torch.ops.primus_turbo_cpp_extension.rmsnorm_fwd(x, gamma, eps)
assert x.is_cuda and gamma.is_cuda, "rmsnorm: x and gamma must be CUDA tensors"
orig_shape = x.shape
H = gamma.shape[0]
assert (
orig_shape[-1] == H
), f"rmsnorm: last dim of x ({orig_shape[-1]}) must equal gamma.shape[0] ({H})"
), f"rmsnorm_residual: shape mismatch {tuple(x.shape)} vs {tuple(residual.shape)}"
orig_shape = x.shape
H = gamma.shape[0]
assert orig_shape[-1] == H
Comment on lines +18 to +23
from primus_turbo.pytorch.kernels.normalization.rmsnorm_impl import (
rmsnorm_bwd_impl,
rmsnorm_bwd_residual_impl,
rmsnorm_fwd_impl,
rmsnorm_fwd_residual_impl,
)
Comment on lines +106 to +112
@staticmethod
def backward(ctx, grad_y: torch.Tensor, grad_xpr: torch.Tensor):
x_plus_r, gamma, rstd = ctx.saved_tensors
dx, dg = rmsnorm_bwd_residual_impl(
grad_y,
grad_xpr,
x_plus_r,
if self.elementwise_affine:
torch.nn.init.ones_(self.weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Comment on lines +70 to +73
# Backward — only flow gradient through y so dxpr = 0.
grad_out = torch.randn_like(y)
y.backward(grad_out)
y_ref.backward(grad_out)
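
The test snippet above backpropagates only through y, so the incoming gradient for the x_plus_r output is zero. A minimal pure-Python sketch of the per-row backward shows why x and residual then receive the same gradient: both are the gradient with respect to s = x + residual. The function names are hypothetical, the formula is the standard RMSNorm derivative rather than a transcription of the Triton kernel, and the finite-difference loop is there only to check it.

```python
import math


def rmsnorm_fwd(s, gamma, eps):
    """Forward for one row; returns (y, rstd)."""
    rstd = 1.0 / math.sqrt(sum(v * v for v in s) / len(s) + eps)
    return [v * rstd * g for v, g in zip(s, gamma)], rstd


def rmsnorm_residual_bwd(grad_y, grad_xpr, s, gamma, rstd):
    """Gradient w.r.t. s = x + residual (shared by x and residual), and dgamma."""
    H = len(s)
    s_hat = [v * rstd for v in s]
    gdy = [g * dy for g, dy in zip(gamma, grad_y)]
    c = sum(a * b for a, b in zip(gdy, s_hat)) / H
    # ds = rstd * (gamma*dy - s_hat * mean(gamma*dy*s_hat)) + grad from x_plus_r output
    ds = [rstd * (a - sh * c) + gx for a, sh, gx in zip(gdy, s_hat, grad_xpr)]
    dgamma = [dy * sh for dy, sh in zip(grad_y, s_hat)]
    return ds, dgamma


x = [0.5, -1.0, 2.0, 0.25]
residual = [1.0, 0.5, -0.5, 0.75]
gamma = [1.0, 0.5, 2.0, 1.5]
grad_y = [0.1, -0.2, 0.3, 0.4]
eps = 1e-6

s = [a + b for a, b in zip(x, residual)]
_, rstd = rmsnorm_fwd(s, gamma, eps)
ds, dgamma = rmsnorm_residual_bwd(grad_y, [0.0] * 4, s, gamma, rstd)
dx, dresidual = ds, ds  # the add has Jacobian [I, I], so both get ds


def loss(x_in):
    """Scalar whose gradient w.r.t. y is grad_y."""
    y, _ = rmsnorm_fwd([a + b for a, b in zip(x_in, residual)], gamma, eps)
    return sum(a * b for a, b in zip(y, grad_y))


# Central finite differences w.r.t. x as an independent check of dx.
h = 1e-6
dx_fd = []
for i in range(len(x)):
    xp = list(x)
    xm = list(x)
    xp[i] += h
    xm[i] -= h
    dx_fd.append((loss(xp) - loss(xm)) / (2 * h))
```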
Copilot AI review requested due to automatic review settings May 18, 2026 08:53
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 11 changed files in this pull request and generated 8 comments.

Comments suppressed due to low confidence (3)

primus_turbo/pytorch/ops/normalization.py:92

  • rmsnorm_residual also relies on Python assert for required checks (CUDA tensors, shape match, last-dim match). For a public op, these validations should be done with explicit exceptions so they are not skipped under python -O, and so error types/messages are consistent.
    def forward(ctx, x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6):
        assert (
            x.is_cuda and residual.is_cuda and gamma.is_cuda
        ), "rmsnorm_residual: x, residual and gamma must be CUDA tensors"
        assert (
            x.shape == residual.shape
        ), f"rmsnorm_residual: shape mismatch {tuple(x.shape)} vs {tuple(residual.shape)}"
        orig_shape = x.shape
        H = gamma.shape[0]
        assert orig_shape[-1] == H
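
One way to act on this suggestion is to move the checks into a helper that raises explicit exceptions. The sketch below is an assumption about how that could look, not code from the PR: the helper name check_rmsnorm_residual_inputs and the choice of ValueError are hypothetical. It only touches .is_cuda and .shape, so it accepts torch.Tensor as well as the SimpleNamespace stand-ins used here to keep the example runnable without a GPU.

```python
from types import SimpleNamespace


def check_rmsnorm_residual_inputs(x, residual, gamma):
    """Validate inputs with explicit exceptions so the checks survive `python -O`."""
    if not (x.is_cuda and residual.is_cuda and gamma.is_cuda):
        raise ValueError("rmsnorm_residual: x, residual and gamma must be CUDA tensors")
    if tuple(x.shape) != tuple(residual.shape):
        raise ValueError(
            f"rmsnorm_residual: shape mismatch {tuple(x.shape)} vs {tuple(residual.shape)}"
        )
    H = gamma.shape[0]
    if x.shape[-1] != H:
        raise ValueError(
            f"rmsnorm_residual: last dim of x ({x.shape[-1]}) must equal gamma.shape[0] ({H})"
        )
    return H


# Stand-ins exposing the same attributes the helper reads from a tensor.
x = SimpleNamespace(is_cuda=True, shape=(2, 8))
residual = SimpleNamespace(is_cuda=True, shape=(2, 8))
gamma = SimpleNamespace(is_cuda=True, shape=(8,))
H = check_rmsnorm_residual_inputs(x, residual, gamma)

bad_residual = SimpleNamespace(is_cuda=True, shape=(2, 4))
try:
    check_rmsnorm_residual_inputs(x, bad_residual, gamma)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```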

primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py:275

  • Same as the standard backward: when ROWS == 1, the residual backward allocates a (B, H) fp32 dg_partial buffer, which can be very large and risk OOM for big B/H. Consider using an in-kernel reduction strategy to avoid materializing the full buffer.
    dx = torch.empty_like(x_plus_r)
    if ROWS == 1:
        dg_partial = torch.empty(B, H, device=x_plus_r.device, dtype=torch.float32)
        rmsnorm_bwd_residual_kernel[(B,)](
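
To put a number on the concern, the size of the fp32 dg_partial buffer can be estimated directly. The shapes below are illustrative values chosen for the arithmetic, not measurements from this PR, and the helper name and the rows_per_block value of 4 are hypothetical.

```python
def dg_partial_bytes(B, H, rows_per_block=1, bytes_per_elem=4):
    """Bytes needed for a (num_programs, H) fp32 partial-dgamma buffer."""
    num_programs = -(-B // rows_per_block)  # ceil(B / rows_per_block)
    return num_programs * H * bytes_per_elem


# ROWS == 1: one partial row per input row, i.e. a full (B, H) fp32 buffer.
single_row = dg_partial_bytes(B=65536, H=8192)
# Reducing 4 rows inside each program before writing (hypothetical value).
reduced = dg_partial_bytes(B=65536, H=8192, rows_per_block=4)

single_row_gib = single_row / 2**30
reduced_mib = reduced / 2**20
```

For this shape the single-row path needs 2 GiB of scratch memory, while reducing four rows per program would need 512 MiB.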

csrc/pytorch/bindings_pytorch.cpp:98

  • dequantize_fp8_rowwise is registered in the CUDA impl table, but no definition is present in the codebase (only a prototype). This will break the extension build at link time. Add the missing implementation or remove the registration.
    // ********* Quantization *********
    m.impl("quantize_fp8_tensorwise", quantize_fp8_tensorwise);
    m.impl("dequantize_fp8_tensorwise", dequantize_fp8_tensorwise);
    m.impl("quantize_fp8_rowwise", quantize_fp8_rowwise);
    m.impl("dequantize_fp8_rowwise", dequantize_fp8_rowwise);

Comment on lines 51 to +52
    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        return _rmsnorm(x, self.weight, self.eps)
Comment thread primus_turbo/pytorch/ops/normalization.py
Comment thread primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py
Comment thread primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py
Comment thread tests/pytorch/ops/test_normalization.py
Comment thread tests/pytorch/ops/test_normalization.py
Comment thread csrc/pytorch/bindings_pytorch.cpp Outdated
Comment thread csrc/pytorch/bindings_pytorch.cpp Outdated
@RuibinCheung RuibinCheung force-pushed the dev/zhangrb/refine_rmsnorm branch from 32a73ba to e8eee83 Compare May 20, 2026 01:08

Labels

None yet

Projects

None yet

2 participants