From 314380b5137e41d297185314cb74bfcc3850c9d5 Mon Sep 17 00:00:00 2001 From: RuibinCheung Date: Mon, 18 May 2026 07:49:55 +0000 Subject: [PATCH 1/2] feat: refine rms norm ops --- csrc/pytorch/normalization/normalization.cpp | 85 ---- .../normalization/normalization_meta.cpp | 22 - .../pytorch/kernels/normalization/__init__.py | 0 .../kernels/normalization/rmsnorm_impl.py | 326 ++++++++++++++ primus_turbo/pytorch/modules/normalization.py | 22 +- primus_turbo/pytorch/ops/normalization.py | 141 ++++-- primus_turbo/triton/normalization/__init__.py | 0 .../triton/normalization/rmsnorm_kernel.py | 416 ++++++++++++++++++ tests/pytorch/ops/test_normalization.py | 57 ++- 9 files changed, 913 insertions(+), 156 deletions(-) delete mode 100644 csrc/pytorch/normalization/normalization.cpp delete mode 100644 csrc/pytorch/normalization/normalization_meta.cpp create mode 100644 primus_turbo/pytorch/kernels/normalization/__init__.py create mode 100644 primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py create mode 100644 primus_turbo/triton/normalization/__init__.py create mode 100644 primus_turbo/triton/normalization/rmsnorm_kernel.py diff --git a/csrc/pytorch/normalization/normalization.cpp b/csrc/pytorch/normalization/normalization.cpp deleted file mode 100644 index b2746b2de..000000000 --- a/csrc/pytorch/normalization/normalization.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -// -// See LICENSE for license information. 
- -#include "primus_turbo/normalization.h" - -#include "../extensions.h" - -namespace primus_turbo::pytorch { - -using namespace primus_turbo::dtype; - -at::Tensor rmsnorm_fwd(const at::Tensor &input, const at::Tensor &gamma, const double eps) { - TORCH_CHECK(input.is_contiguous(), "rmsnorm_fwd: input must be contiguous."); - TORCH_CHECK(gamma.is_contiguous(), "rmsnorm_fwd: gamma must be contiguous."); - - const int64_t inner_len = gamma.numel(); - const int64_t outer_len = input.numel() / inner_len; - auto output = at::empty_like(input); - - TORCH_CHECK(input.numel() % inner_len == 0, "input.numel() must be divisible by gamma.numel()"); - - auto stream = at::cuda::getCurrentCUDAStream(); - if (input.scalar_type() == at::kFloat) { - rmsnorm_fwd_impl(input.data_ptr(), gamma.data_ptr(), - output.data_ptr(), inner_len, outer_len, - static_cast(eps), stream); - } else if (input.scalar_type() == at::kHalf) { - rmsnorm_fwd_impl(reinterpret_cast(input.data_ptr()), - reinterpret_cast(gamma.data_ptr()), - reinterpret_cast(output.data_ptr()), inner_len, - outer_len, static_cast(eps), stream); - } else if (input.scalar_type() == at::kBFloat16) { - rmsnorm_fwd_impl(reinterpret_cast(input.data_ptr()), - reinterpret_cast(gamma.data_ptr()), - reinterpret_cast(output.data_ptr()), inner_len, - outer_len, static_cast(eps), stream); - } else { - PRIMUS_TURBO_ERROR("RMSNorm only support : [float32, float16, bfloat16]"); - } - return output; -} - -std::vector rmsnorm_bwd(const at::Tensor &input, const at::Tensor &gamma, - const at::Tensor &grad_output, const double eps) { - TORCH_CHECK(input.is_contiguous(), "rmsnorm_bwd: input must be contiguous."); - TORCH_CHECK(gamma.is_contiguous(), "rmsnorm_bwd: gamma must be contiguous."); - TORCH_CHECK(grad_output.is_contiguous(), "rmsnorm_bwd: grad_output must be contiguous."); - - const int64_t inner_len = gamma.numel(); - const int64_t outer_len = input.numel() / inner_len; - - TORCH_CHECK(input.numel() % inner_len == 0, "input.numel() must 
be divisible by gamma.numel()"); - - auto input_grad = at::empty_like(input); - auto gamma_grad = at::empty_like(input); - - auto stream = at::cuda::getCurrentCUDAStream(); - if (input.scalar_type() == at::kFloat) { - rmsnorm_bwd_impl(input.data_ptr(), gamma.data_ptr(), - grad_output.data_ptr(), input_grad.data_ptr(), - gamma_grad.data_ptr(), inner_len, outer_len, - static_cast(eps), stream); - } else if (input.scalar_type() == at::kHalf) { - rmsnorm_bwd_impl(reinterpret_cast(input.data_ptr()), - reinterpret_cast(gamma.data_ptr()), - reinterpret_cast(grad_output.data_ptr()), - reinterpret_cast(input_grad.data_ptr()), - reinterpret_cast(gamma_grad.data_ptr()), inner_len, - outer_len, static_cast(eps), stream); - } else if (input.scalar_type() == at::kBFloat16) { - rmsnorm_bwd_impl(reinterpret_cast(input.data_ptr()), - reinterpret_cast(gamma.data_ptr()), - reinterpret_cast(grad_output.data_ptr()), - reinterpret_cast(input_grad.data_ptr()), - reinterpret_cast(gamma_grad.data_ptr()), inner_len, - outer_len, static_cast(eps), stream); - } else { - PRIMUS_TURBO_ERROR("RMSNorm only support : [float32, float16, bfloat16]"); - } - - return {input_grad, gamma_grad}; -} - -} // namespace primus_turbo::pytorch diff --git a/csrc/pytorch/normalization/normalization_meta.cpp b/csrc/pytorch/normalization/normalization_meta.cpp deleted file mode 100644 index ff12d0c4c..000000000 --- a/csrc/pytorch/normalization/normalization_meta.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -// -// See LICENSE for license information. 
- -#include - -namespace primus_turbo::pytorch { - -at::Tensor rmsnorm_fwd_meta(const at::Tensor &input, const at::Tensor &gamma, const double eps) { - return at::empty_like(input, at::device(at::kMeta)); -} - -std::vector rmsnorm_bwd_meta(const at::Tensor &input, const at::Tensor &gamma, - const at::Tensor &grad_output, const double eps) { - - at::Tensor grad_input = at::empty_like(input, at::device(at::kMeta)); - at::Tensor grad_gamma = at::empty_like(input, at::device(at::kMeta)); - - return {grad_input, grad_gamma}; -} - -} // namespace primus_turbo::pytorch diff --git a/primus_turbo/pytorch/kernels/normalization/__init__.py b/primus_turbo/pytorch/kernels/normalization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py b/primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py new file mode 100644 index 000000000..f5dbd620b --- /dev/null +++ b/primus_turbo/pytorch/kernels/normalization/rmsnorm_impl.py @@ -0,0 +1,326 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Python-side wrappers that launch the Triton RMSNorm kernels.""" +from __future__ import annotations + +from typing import Optional, Tuple + +import torch + +from primus_turbo.triton.normalization.rmsnorm_kernel import ( + rmsnorm_bwd_kernel, + rmsnorm_bwd_kernel_multi_row, + rmsnorm_bwd_residual_kernel, + rmsnorm_bwd_residual_kernel_multi_row, + rmsnorm_fwd_kernel, + rmsnorm_fwd_kernel_multi_row, + rmsnorm_fwd_residual_kernel, + rmsnorm_fwd_residual_kernel_multi_row, +) + + +def _next_pow2(x: int) -> int: + p = 1 + while p < x: + p <<= 1 + return p + + +def _reshape_batch_hidden(x: torch.Tensor, H: int) -> torch.Tensor: + """Flatten to [B, H] without forcing a contiguous copy. 
+ + When ``reshape`` can return a view, the original strides are kept so the Triton kernels read/write strided rows + directly, avoiding implicit ``_to_copy`` kernels on the autograd hot path; otherwise ``reshape`` falls back to a copy. + """ + if x.shape[-1] != H: + raise ValueError(f"last dim mismatch: expected H={H}, got shape={tuple(x.shape)}") + return x.reshape(-1, H) + + +def _pick_config(H: int, B: int) -> Tuple[int, int, int, int]: + """Return (BLOCK_H, ROWS_PER_BLOCK, num_warps, num_stages). + + Multi-row mode (ROWS_PER_BLOCK > 1) wins when H is small AND B is huge, + because the grid size of one program per row becomes the bottleneck. + """ + BLOCK_H = _next_pow2(H) + if BLOCK_H <= 256 and B >= 4096: + ROWS = 16 if BLOCK_H <= 128 else 8 + return BLOCK_H, ROWS, 4, 2 + if BLOCK_H <= 256: + return BLOCK_H, 1, 1, 1 + if BLOCK_H <= 1024: + return BLOCK_H, 1, 4, 2 + if BLOCK_H <= 4096: + return BLOCK_H, 1, 8, 2 + return BLOCK_H, 1, 16, 2 + + +def rmsnorm_fwd_impl( + x: torch.Tensor, gamma: torch.Tensor, eps: float +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, int]: + """Forward launcher. + + Returns ``(y, x2, rstd, BLOCK_H, ROWS, num_warps, num_stages)``. ``x2`` is + ``x`` reshaped to [B, H] (a view whenever ``reshape`` does not need to copy), saved for backward. ``rstd`` is the per-row + reciprocal std needed by backward. 
+ """ + H = gamma.shape[0] + x2 = _reshape_batch_hidden(x, H) + B = x2.shape[0] + y = torch.empty_like(x2) + rstd = torch.empty(B, device=x.device, dtype=torch.float32) + BLOCK_H, ROWS, num_warps, num_stages = _pick_config(H, B) + if ROWS == 1: + rmsnorm_fwd_kernel[(B,)]( + x2, + gamma, + y, + rstd, + x2.stride(0), + x2.stride(1), + y.stride(0), + y.stride(1), + H=H, + eps=eps, + BLOCK_H=BLOCK_H, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + grid = ((B + ROWS - 1) // ROWS,) + rmsnorm_fwd_kernel_multi_row[grid]( + x2, + gamma, + y, + rstd, + x2.stride(0), + x2.stride(1), + y.stride(0), + y.stride(1), + B=B, + H=H, + eps=eps, + BLOCK_H=BLOCK_H, + ROWS_PER_BLOCK=ROWS, + num_warps=num_warps, + num_stages=num_stages, + ) + return y, x2, rstd, BLOCK_H, ROWS, num_warps, num_stages + + +def rmsnorm_bwd_impl( + dy: torch.Tensor, + x2: torch.Tensor, + gamma: torch.Tensor, + rstd: torch.Tensor, + BLOCK_H: int, + ROWS: int, + num_warps: int, + num_stages: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward launcher. 
Returns ``(dx [B, H], dgamma [H])``.""" + H = gamma.shape[0] + B = x2.shape[0] + dy2 = _reshape_batch_hidden(dy, H) + dx = torch.empty_like(x2) + if ROWS == 1: + dg_partial = torch.empty(B, H, device=x2.device, dtype=torch.float32) + rmsnorm_bwd_kernel[(B,)]( + dy2, + x2, + gamma, + rstd, + dx, + dg_partial, + x2.stride(0), + x2.stride(1), + dy2.stride(0), + dy2.stride(1), + dx.stride(0), + dx.stride(1), + dg_partial.stride(0), + H=H, + BLOCK_H=BLOCK_H, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + num_programs = (B + ROWS - 1) // ROWS + dg_partial = torch.empty(num_programs, H, device=x2.device, dtype=torch.float32) + grid = (num_programs,) + rmsnorm_bwd_kernel_multi_row[grid]( + dy2, + x2, + gamma, + rstd, + dx, + dg_partial, + x2.stride(0), + x2.stride(1), + dy2.stride(0), + dy2.stride(1), + dx.stride(0), + dx.stride(1), + dg_partial.stride(0), + B=B, + H=H, + BLOCK_H=BLOCK_H, + ROWS_PER_BLOCK=ROWS, + num_warps=num_warps, + num_stages=num_stages, + ) + dg = dg_partial.sum(dim=0).to(gamma.dtype) + return dx, dg + + +def rmsnorm_fwd_residual_impl( + x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, eps: float +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, int]: + """Fused (x + residual) -> rmsnorm forward. + + Returns ``(y, x_plus_r, rstd, BLOCK_H, ROWS, num_warps, num_stages)``. Both + ``y`` and ``x_plus_r`` are returned in [B, H] layout (caller is expected to + reshape back to the original logical shape if needed). 
+ """ + H = gamma.shape[0] + x2 = _reshape_batch_hidden(x, H) + r2 = _reshape_batch_hidden(residual, H) + B = x2.shape[0] + y = torch.empty_like(x2) + x_plus_r = torch.empty_like(x2) + rstd = torch.empty(B, device=x.device, dtype=torch.float32) + BLOCK_H, ROWS, num_warps, num_stages = _pick_config(H, B) + if ROWS == 1: + rmsnorm_fwd_residual_kernel[(B,)]( + x2, + r2, + gamma, + y, + x_plus_r, + rstd, + x2.stride(0), + x2.stride(1), + r2.stride(0), + r2.stride(1), + y.stride(0), + y.stride(1), + x_plus_r.stride(0), + x_plus_r.stride(1), + H=H, + eps=eps, + BLOCK_H=BLOCK_H, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + grid = ((B + ROWS - 1) // ROWS,) + rmsnorm_fwd_residual_kernel_multi_row[grid]( + x2, + r2, + gamma, + y, + x_plus_r, + rstd, + x2.stride(0), + x2.stride(1), + r2.stride(0), + r2.stride(1), + y.stride(0), + y.stride(1), + x_plus_r.stride(0), + x_plus_r.stride(1), + B=B, + H=H, + eps=eps, + BLOCK_H=BLOCK_H, + ROWS_PER_BLOCK=ROWS, + num_warps=num_warps, + num_stages=num_stages, + ) + return y, x_plus_r, rstd, BLOCK_H, ROWS, num_warps, num_stages + + +def rmsnorm_bwd_residual_impl( + dy: torch.Tensor, + dxpr: Optional[torch.Tensor], + x_plus_r: torch.Tensor, + gamma: torch.Tensor, + rstd: torch.Tensor, + BLOCK_H: int, + ROWS: int, + num_warps: int, + num_stages: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward launcher for the residual variant. Returns ``(dx [B, H], dgamma [H])``. + + Caller is expected to return the same ``dx`` for both ``x`` and ``residual`` + inputs because the upstream ``+`` has Jacobian ``[I, I]``. + """ + H = gamma.shape[0] + B = x_plus_r.shape[0] + dy2 = _reshape_batch_hidden(dy, H) + if dxpr is None: + # When ``x_plus_r`` is unused downstream, autograd may hand us None. + # Substitute zeros so the kernel sees a valid pointer. 
+ dxpr2 = torch.zeros_like(x_plus_r) + else: + dxpr2 = _reshape_batch_hidden(dxpr, H) + dx = torch.empty_like(x_plus_r) + if ROWS == 1: + dg_partial = torch.empty(B, H, device=x_plus_r.device, dtype=torch.float32) + rmsnorm_bwd_residual_kernel[(B,)]( + dy2, + dxpr2, + x_plus_r, + gamma, + rstd, + dx, + dg_partial, + x_plus_r.stride(0), + x_plus_r.stride(1), + dy2.stride(0), + dy2.stride(1), + dxpr2.stride(0), + dxpr2.stride(1), + dx.stride(0), + dx.stride(1), + dg_partial.stride(0), + H=H, + BLOCK_H=BLOCK_H, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + num_programs = (B + ROWS - 1) // ROWS + dg_partial = torch.empty(num_programs, H, device=x_plus_r.device, dtype=torch.float32) + grid = (num_programs,) + rmsnorm_bwd_residual_kernel_multi_row[grid]( + dy2, + dxpr2, + x_plus_r, + gamma, + rstd, + dx, + dg_partial, + x_plus_r.stride(0), + x_plus_r.stride(1), + dy2.stride(0), + dy2.stride(1), + dxpr2.stride(0), + dxpr2.stride(1), + dx.stride(0), + dx.stride(1), + dg_partial.stride(0), + B=B, + H=H, + BLOCK_H=BLOCK_H, + ROWS_PER_BLOCK=ROWS, + num_warps=num_warps, + num_stages=num_stages, + ) + dg = dg_partial.sum(dim=0).to(gamma.dtype) + return dx, dg diff --git a/primus_turbo/pytorch/modules/normalization.py b/primus_turbo/pytorch/modules/normalization.py index fe4c53796..bd4f02068 100644 --- a/primus_turbo/pytorch/modules/normalization.py +++ b/primus_turbo/pytorch/modules/normalization.py @@ -9,13 +9,15 @@ import torch from torch import Size -from torch.nn import functional as F from torch.nn.parameter import Parameter +from primus_turbo.pytorch.ops.normalization import rmsnorm as _rmsnorm + __all__ = ["RMSNorm"] class RMSNorm(torch.nn.Module): + def __init__( self, normalized_shape: Union[int, list[int], Size], @@ -27,10 +29,14 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if isinstance(normalized_shape, numbers.Integral): - # mypy error: incompatible types in assignment normalized_shape = (normalized_shape,) 
# type: ignore[assignment] self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] - self.eps = eps + if len(self.normalized_shape) != 1: + raise ValueError( + "primus_turbo RMSNorm currently only supports a 1-D normalized_shape " + f"(got {self.normalized_shape}). Use the last hidden dimension." + ) + self.eps = eps if eps is not None else 1e-6 self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -39,19 +45,13 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ - Resets parameters based on their initialization used in __init__. - """ if self.elementwise_affine: torch.nn.init.ones_(self.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) + return _rmsnorm(x, self.weight, self.eps) def extra_repr(self) -> str: - """ - Extra information about the module. - """ - return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format( + return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format( **self.__dict__ ) diff --git a/primus_turbo/pytorch/ops/normalization.py b/primus_turbo/pytorch/ops/normalization.py index a4332d9c7..2b083e60a 100644 --- a/primus_turbo/pytorch/ops/normalization.py +++ b/primus_turbo/pytorch/ops/normalization.py @@ -3,51 +3,140 @@ # # See LICENSE for license information. ############################################################################### +"""Triton-backed RMSNorm ops (standard + fused residual variant). 
+ +Public API: + - ``rmsnorm(x, gamma, eps=1e-6) -> y`` + - ``rmsnorm_residual(x, residual, gamma, eps=1e-6) -> (y, x_plus_r)`` +""" +from __future__ import annotations + +from typing import Tuple import torch -__all__ = ["rmsnorm"] +from primus_turbo.pytorch.kernels.normalization.rmsnorm_impl import ( + rmsnorm_bwd_impl, + rmsnorm_bwd_residual_impl, + rmsnorm_fwd_impl, + rmsnorm_fwd_residual_impl, +) +__all__ = ["rmsnorm", "rmsnorm_residual"] -class RMSNormFunction(torch.autograd.Function): + +class _RMSNormFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6): - y = torch.ops.primus_turbo_cpp_extension.rmsnorm_fwd(x, gamma, eps) + assert x.is_cuda and gamma.is_cuda, "rmsnorm: x and gamma must be CUDA tensors" + orig_shape = x.shape + H = gamma.shape[0] + assert ( + orig_shape[-1] == H + ), f"rmsnorm: last dim of x ({orig_shape[-1]}) must equal gamma.shape[0] ({H})" + + y, x2, rstd, BLOCK_H, ROWS, num_warps, num_stages = rmsnorm_fwd_impl(x, gamma, eps) - ctx.save_for_backward(x, gamma) + ctx.save_for_backward(x2, gamma, rstd) ctx.eps = eps - return y + ctx.orig_shape = orig_shape + ctx.BLOCK_H = BLOCK_H + ctx.ROWS = ROWS + ctx.num_warps = num_warps + ctx.num_stages = num_stages + return y.reshape(orig_shape) @staticmethod - def backward_torch(ctx, grad_out: torch.Tensor): - x, gamma = ctx.saved_tensors - eps = ctx.eps + def backward(ctx, grad_out: torch.Tensor): + x2, gamma, rstd = ctx.saved_tensors + dx, dg = rmsnorm_bwd_impl( + grad_out, + x2, + gamma, + rstd, + ctx.BLOCK_H, + ctx.ROWS, + ctx.num_warps, + ctx.num_stages, + ) + return dx.reshape(ctx.orig_shape), dg, None - N = x.size(-1) - x_squared = x * x - x_squared_sum = x_squared.sum(dim=-1, keepdim=True) - x_norm = torch.rsqrt(x_squared_sum / N + eps) - grad_x_norm = grad_out * gamma # scale by g - grad_x_part1 = grad_x_norm * x_norm # apply normalized scaling +class _RMSNormResidualFunction(torch.autograd.Function): + """Fused (x + 
residual) -> rmsnorm -> (y, x_plus_r). - grad_x_squared_sum = (-0.5 * (x_squared_sum / N + eps) ** (-1.5)) * (2 * x / N) - grad_x_part2 = grad_x_squared_sum * (x * grad_x_norm).sum(dim=-1, keepdim=True) + Replaces the pattern:: - grad_x = grad_x_part1 + grad_x_part2 + h = x + residual + y = rmsnorm(h) - # Gradient w.r.t. g - grad_g = (grad_out * x * x_norm).sum(dim=0) + with a single Triton kernel that computes both ``h`` and ``y`` from one load + of ``x`` and one load of ``residual``. ``h`` is exposed as a second return + so the caller can feed it to the next residual add. - return grad_x, grad_g, None + Backward returns ``(dx, dresidual, dgamma, None)`` where ``dx == dresidual`` + because the upstream ``+`` has Jacobian ``[I, I]``. + """ @staticmethod - def backward(ctx, grad_out: torch.Tensor): - x, gamma = ctx.saved_tensors - eps = ctx.eps - grad_x, grad_g = torch.ops.primus_turbo_cpp_extension.rmsnorm_bwd(x, gamma, grad_out, eps) - return grad_x, grad_g.sum(dim=0), None + def forward(ctx, x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6): + assert ( + x.is_cuda and residual.is_cuda and gamma.is_cuda + ), "rmsnorm_residual: x, residual and gamma must be CUDA tensors" + assert ( + x.shape == residual.shape + ), f"rmsnorm_residual: shape mismatch {tuple(x.shape)} vs {tuple(residual.shape)}" + orig_shape = x.shape + H = gamma.shape[0] + assert orig_shape[-1] == H + + y, x_plus_r, rstd, BLOCK_H, ROWS, num_warps, num_stages = rmsnorm_fwd_residual_impl( + x, residual, gamma, eps + ) + + ctx.save_for_backward(x_plus_r, gamma, rstd) + ctx.eps = eps + ctx.orig_shape = orig_shape + ctx.BLOCK_H = BLOCK_H + ctx.ROWS = ROWS + ctx.num_warps = num_warps + ctx.num_stages = num_stages + return y.reshape(orig_shape), x_plus_r.reshape(orig_shape) + + @staticmethod + def backward(ctx, grad_y: torch.Tensor, grad_xpr: torch.Tensor): + x_plus_r, gamma, rstd = ctx.saved_tensors + dx, dg = rmsnorm_bwd_residual_impl( + grad_y, + grad_xpr, + x_plus_r, + 
gamma, + rstd, + ctx.BLOCK_H, + ctx.ROWS, + ctx.num_warps, + ctx.num_stages, + ) + dx_out = dx.reshape(ctx.orig_shape) + # Jacobian of add() is [I, I] -> both x and residual get the same grad. + return dx_out, dx_out, dg, None def rmsnorm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: - return RMSNormFunction.apply(x, gamma, eps) + """Triton-backed RMSNorm: ``y = (x / rms(x)) * gamma``.""" + return _RMSNormFunction.apply(x, gamma, eps) + + +def rmsnorm_residual( + x: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + eps: float = 1e-6, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fused residual-add + RMSNorm. + + Returns ``(rmsnorm(x + residual) * gamma, x + residual)`` in a single kernel + launch. The second element is exposed so the caller can feed it directly to + the next residual add, removing the standalone elementwise add. + """ + return _RMSNormResidualFunction.apply(x, residual, gamma, eps) diff --git a/primus_turbo/triton/normalization/__init__.py b/primus_turbo/triton/normalization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/primus_turbo/triton/normalization/rmsnorm_kernel.py b/primus_turbo/triton/normalization/rmsnorm_kernel.py new file mode 100644 index 000000000..5f71dc0cf --- /dev/null +++ b/primus_turbo/triton/normalization/rmsnorm_kernel.py @@ -0,0 +1,416 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Triton RMSNorm kernels (single-row + multi-row fwd/bwd, plus residual variant). + +The kernels are stride-aware on both batch and hidden dims so callers can pass +non-contiguous views (e.g. ``hidden_states.reshape(-1, H)`` on a strided +fp16/bf16 tensor) without forcing a ``.contiguous()`` copy. 
+ +Backward formulation (standard): + grad_x = (grad_out * gamma * rstd) - x * rstd^3 * sum_over_hidden(grad_out * gamma * x) / H + grad_g = sum_over_batch(grad_out * x * rstd) + +For the residual variant the bwd additionally folds the gradient flowing through +``x_plus_r`` (consumed by the next residual-add) into ``dx``. The autograd +function returns the same gradient for both ``x`` and ``residual`` since their +sum has Jacobian ``[I, I]``. + +A 2-stage bwd is used. The multi-row variants reduce ``dgamma`` *inside* each +program over its ``ROWS_PER_BLOCK`` rows, so the partial buffer is +``(num_programs, H)`` instead of ``(B, H)``. This is essential at small-H, +huge-B shapes (e.g. q_norm in MoE attention) where ``(B, H)`` would otherwise +cost an unreasonable amount of workspace memory. +""" +from __future__ import annotations + +import triton +import triton.language as tl + + +# --------------------------------------------------------------------------- +# Forward kernel — one row per program. Used when H is large. +# --------------------------------------------------------------------------- +@triton.jit +def rmsnorm_fwd_kernel( + X_ptr, + G_ptr, + Y_ptr, + RSTD_ptr, + stride_xb, + stride_xh, + stride_yb, + stride_yh, + H: tl.constexpr, + eps, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + x_ptrs = X_ptr + row * stride_xb + offs * stride_xh + y_ptrs = Y_ptr + row * stride_yb + offs * stride_yh + g_ptrs = G_ptr + offs + mask = offs < H + + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + var = tl.sum(x * x, axis=0) / H + rstd = tl.rsqrt(var + eps) + g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) + y = (x * rstd * g).to(Y_ptr.dtype.element_ty) + tl.store(y_ptrs, y, mask=mask) + tl.store(RSTD_ptr + row, rstd) + + +# --------------------------------------------------------------------------- +# Forward kernel — N rows per program (small H, huge B). 
Reduces grid size, +# which is critical when the launch / scheduling cost dominates. +# --------------------------------------------------------------------------- +@triton.jit +def rmsnorm_fwd_kernel_multi_row( + X_ptr, + G_ptr, + Y_ptr, + RSTD_ptr, + stride_xb, + stride_xh, + stride_yb, + stride_yh, + B, + H: tl.constexpr, + eps, + BLOCK_H: tl.constexpr, + ROWS_PER_BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * ROWS_PER_BLOCK + row_offs = row_start + tl.arange(0, ROWS_PER_BLOCK) + row_mask = row_offs < B + + h_offs = tl.arange(0, BLOCK_H) + h_mask = h_offs < H + + x_ptrs = X_ptr + row_offs[:, None] * stride_xb + h_offs[None, :] * stride_xh + y_ptrs = Y_ptr + row_offs[:, None] * stride_yb + h_offs[None, :] * stride_yh + g_ptrs = G_ptr + h_offs + + full_mask = row_mask[:, None] & h_mask[None, :] + x = tl.load(x_ptrs, mask=full_mask, other=0.0).to(tl.float32) + g = tl.load(g_ptrs, mask=h_mask, other=0.0).to(tl.float32) + + var = tl.sum(x * x, axis=1) / H + rstd = tl.rsqrt(var + eps) + y = (x * rstd[:, None] * g[None, :]).to(Y_ptr.dtype.element_ty) + tl.store(y_ptrs, y, mask=full_mask) + tl.store(RSTD_ptr + row_offs, rstd, mask=row_mask) + + +# --------------------------------------------------------------------------- +# Forward kernel — fused residual add. Computes +# x_plus_r = x + residual +# y = rmsnorm(x_plus_r) * gamma +# in one pass and exposes ``x_plus_r`` for the next residual add. Saves +# ``x_plus_r`` (input dtype) and ``rstd`` (fp32) for backward. 
+# --------------------------------------------------------------------------- +@triton.jit +def rmsnorm_fwd_residual_kernel( + X_ptr, + R_ptr, + G_ptr, + Y_ptr, + XPR_ptr, + RSTD_ptr, + stride_xb, + stride_xh, + stride_rb, + stride_rh, + stride_yb, + stride_yh, + stride_xprb, + stride_xprh, + H: tl.constexpr, + eps, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + x_ptrs = X_ptr + row * stride_xb + offs * stride_xh + r_ptrs = R_ptr + row * stride_rb + offs * stride_rh + y_ptrs = Y_ptr + row * stride_yb + offs * stride_yh + xpr_ptrs = XPR_ptr + row * stride_xprb + offs * stride_xprh + g_ptrs = G_ptr + offs + mask = offs < H + + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + r = tl.load(r_ptrs, mask=mask, other=0.0).to(tl.float32) + xpr = x + r + tl.store(xpr_ptrs, xpr.to(XPR_ptr.dtype.element_ty), mask=mask) + + var = tl.sum(xpr * xpr, axis=0) / H + rstd = tl.rsqrt(var + eps) + g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) + y = (xpr * rstd * g).to(Y_ptr.dtype.element_ty) + tl.store(y_ptrs, y, mask=mask) + tl.store(RSTD_ptr + row, rstd) + + +@triton.jit +def rmsnorm_fwd_residual_kernel_multi_row( + X_ptr, + R_ptr, + G_ptr, + Y_ptr, + XPR_ptr, + RSTD_ptr, + stride_xb, + stride_xh, + stride_rb, + stride_rh, + stride_yb, + stride_yh, + stride_xprb, + stride_xprh, + B, + H: tl.constexpr, + eps, + BLOCK_H: tl.constexpr, + ROWS_PER_BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * ROWS_PER_BLOCK + row_offs = row_start + tl.arange(0, ROWS_PER_BLOCK) + row_mask = row_offs < B + + h_offs = tl.arange(0, BLOCK_H) + h_mask = h_offs < H + + x_ptrs = X_ptr + row_offs[:, None] * stride_xb + h_offs[None, :] * stride_xh + r_ptrs = R_ptr + row_offs[:, None] * stride_rb + h_offs[None, :] * stride_rh + y_ptrs = Y_ptr + row_offs[:, None] * stride_yb + h_offs[None, :] * stride_yh + xpr_ptrs = XPR_ptr + row_offs[:, None] * stride_xprb + h_offs[None, :] * stride_xprh + g_ptrs = G_ptr + h_offs + + full_mask = 
row_mask[:, None] & h_mask[None, :] + x = tl.load(x_ptrs, mask=full_mask, other=0.0).to(tl.float32) + r = tl.load(r_ptrs, mask=full_mask, other=0.0).to(tl.float32) + g = tl.load(g_ptrs, mask=h_mask, other=0.0).to(tl.float32) + + xpr = x + r + tl.store(xpr_ptrs, xpr.to(XPR_ptr.dtype.element_ty), mask=full_mask) + + var = tl.sum(xpr * xpr, axis=1) / H + rstd = tl.rsqrt(var + eps) + y = (xpr * rstd[:, None] * g[None, :]).to(Y_ptr.dtype.element_ty) + tl.store(y_ptrs, y, mask=full_mask) + tl.store(RSTD_ptr + row_offs, rstd, mask=row_mask) + + +# --------------------------------------------------------------------------- +# Backward — stage 0: per-row dx + per-row (or per-program) partial dgamma. +# --------------------------------------------------------------------------- +@triton.jit +def rmsnorm_bwd_kernel( + DY_ptr, + X_ptr, + G_ptr, + RSTD_ptr, + DX_ptr, + DG_PART_ptr, + stride_xb, + stride_xh, + stride_dyb, + stride_dyh, + stride_dxb, + stride_dxh, + stride_dgb, + H: tl.constexpr, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < H + + x_ptrs = X_ptr + row * stride_xb + offs * stride_xh + dy_ptrs = DY_ptr + row * stride_dyb + offs * stride_dyh + dx_ptrs = DX_ptr + row * stride_dxb + offs * stride_dxh + dgp_ptrs = DG_PART_ptr + row * stride_dgb + offs + g_ptrs = G_ptr + offs + + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32) + g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) + rstd = tl.load(RSTD_ptr + row).to(tl.float32) + + x_hat = x * rstd + dxhat = dy * g + m = tl.sum(dxhat * x_hat, axis=0) / H + dx = (dxhat - x_hat * m) * rstd + dgp = dy * x_hat + + tl.store(dx_ptrs, dx.to(DX_ptr.dtype.element_ty), mask=mask) + tl.store(dgp_ptrs, dgp, mask=mask) + + +@triton.jit +def rmsnorm_bwd_kernel_multi_row( + DY_ptr, + X_ptr, + G_ptr, + RSTD_ptr, + DX_ptr, + DG_PART_ptr, + stride_xb, + stride_xh, + stride_dyb, + stride_dyh, + stride_dxb, + 
stride_dxh, + stride_dgp, + B, + H: tl.constexpr, + BLOCK_H: tl.constexpr, + ROWS_PER_BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + row_start = pid * ROWS_PER_BLOCK + row_offs = row_start + tl.arange(0, ROWS_PER_BLOCK) + row_mask = row_offs < B + h_offs = tl.arange(0, BLOCK_H) + h_mask = h_offs < H + + x_ptrs = X_ptr + row_offs[:, None] * stride_xb + h_offs[None, :] * stride_xh + dy_ptrs = DY_ptr + row_offs[:, None] * stride_dyb + h_offs[None, :] * stride_dyh + dx_ptrs = DX_ptr + row_offs[:, None] * stride_dxb + h_offs[None, :] * stride_dxh + dgp_ptrs = DG_PART_ptr + pid * stride_dgp + h_offs + g_ptrs = G_ptr + h_offs + + full_mask = row_mask[:, None] & h_mask[None, :] + x = tl.load(x_ptrs, mask=full_mask, other=0.0).to(tl.float32) + dy = tl.load(dy_ptrs, mask=full_mask, other=0.0).to(tl.float32) + g = tl.load(g_ptrs, mask=h_mask, other=0.0).to(tl.float32) + rstd = tl.load(RSTD_ptr + row_offs, mask=row_mask, other=0.0).to(tl.float32) + + x_hat = x * rstd[:, None] + dxhat = dy * g[None, :] + m = tl.sum(dxhat * x_hat, axis=1) / H + dx = (dxhat - x_hat * m[:, None]) * rstd[:, None] + + tl.store(dx_ptrs, dx.to(DX_ptr.dtype.element_ty), mask=full_mask) + + # Per-program dgamma reduction — mask out-of-range rows to zero so any + # padding tail contributes nothing. Writes one fp32 [H] slab per program + # instead of ROWS_PER_BLOCK rows. + dgp_block = (dy * x_hat) * row_mask[:, None].to(tl.float32) + dgp_row = tl.sum(dgp_block, axis=0) + tl.store(dgp_ptrs, dgp_row, mask=h_mask) + + +# --------------------------------------------------------------------------- +# Backward — fused residual variant. Adds the gradient that flows through +# ``x_plus_r`` (consumed by the next residual add) to the standard RMSNorm dx. 
+# ---------------------------------------------------------------------------
+@triton.jit
+def rmsnorm_bwd_residual_kernel(  # RMSNorm-with-residual backward; each program instance handles one row.
+    DY_ptr,  # upstream gradient of the normalized output
+    DXPR_ptr,  # upstream gradient of the (x + residual) output
+    XPR_ptr,  # pre-normalization input; presumably x + residual saved by the forward pass
+    G_ptr,  # gamma, read with unit stride
+    RSTD_ptr,  # one scale per row; presumably 1 / RMS from the forward pass - confirm against the launcher
+    DX_ptr,  # output: gradient w.r.t. the pre-normalization input
+    DG_PART_ptr,  # output: per-row dy * x_hat; presumably summed over rows by the caller to form dgamma
+    stride_xprb,  # row stride of XPR
+    stride_xprh,  # element stride of XPR within a row
+    stride_dyb,  # row stride of DY
+    stride_dyh,  # element stride of DY within a row
+    stride_dxprb,  # row stride of DXPR
+    stride_dxprh,  # element stride of DXPR within a row
+    stride_dxb,  # row stride of DX
+    stride_dxh,  # element stride of DX within a row
+    stride_dgb,  # row stride of DG_PART (its inner stride is taken as 1)
+    H: tl.constexpr,  # number of elements per row
+    BLOCK_H: tl.constexpr,  # tile width; there is no loop over H, so this must be >= H
+):
+    row = tl.program_id(0).to(tl.int64)  # widen: an int32 row * stride product wraps once the offset reaches 2**31
+    offs = tl.arange(0, BLOCK_H)
+    mask = offs < H  # lanes past the end of the row are inactive
+
+    xpr_ptrs = XPR_ptr + row * stride_xprb + offs * stride_xprh
+    dy_ptrs = DY_ptr + row * stride_dyb + offs * stride_dyh
+    dxpr_ptrs = DXPR_ptr + row * stride_dxprb + offs * stride_dxprh
+    dx_ptrs = DX_ptr + row * stride_dxb + offs * stride_dxh
+    dgp_ptrs = DG_PART_ptr + row * stride_dgb + offs
+    g_ptrs = G_ptr + offs
+
+    xpr = tl.load(xpr_ptrs, mask=mask, other=0.0).to(tl.float32)  # all arithmetic below is done in fp32
+    dy = tl.load(dy_ptrs, mask=mask, other=0.0).to(tl.float32)
+    dxpr = tl.load(dxpr_ptrs, mask=mask, other=0.0).to(tl.float32)
+    g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32)
+    rstd = tl.load(RSTD_ptr + row).to(tl.float32)
+
+    x_hat = xpr * rstd  # normalized input
+    dxhat = dy * g  # gradient w.r.t. x_hat
+    m = tl.sum(dxhat * x_hat, axis=0) / H  # row mean of dxhat * x_hat; masked lanes contribute 0
+    dx_norm = (dxhat - x_hat * m) * rstd  # gradient flowing back through the normalization
+    dx = dx_norm + dxpr  # plus the gradient that arrived through the (x + residual) output
+    dgp = dy * x_hat  # this row's contribution to dgamma
+
+    tl.store(dx_ptrs, dx.to(DX_ptr.dtype.element_ty), mask=mask)  # cast back to DX's element type
+    tl.store(dgp_ptrs, dgp, mask=mask)
+
+
+@triton.jit
+def rmsnorm_bwd_residual_kernel_multi_row(  # Same math as the per-row kernel, ROWS_PER_BLOCK rows per program.
+    DY_ptr,  # upstream gradient of the normalized output
+    DXPR_ptr,  # upstream gradient of the (x + residual) output
+    XPR_ptr,  # pre-normalization input; presumably x + residual saved by the forward pass
+    G_ptr,  # gamma, read with unit stride
+    RSTD_ptr,  # one scale per row; presumably 1 / RMS from the forward pass - confirm against the launcher
+    DX_ptr,  # output: gradient w.r.t. the pre-normalization input
+    DG_PART_ptr,  # output: one partial dgamma row per program (indexed by program id, not by input row)
+    stride_xprb,  # row stride of XPR
+    stride_xprh,  # element stride of XPR within a row
+    stride_dyb,  # row stride of DY
+    stride_dyh,  # element stride of DY within a row
+    stride_dxprb,  # row stride of DXPR
+    stride_dxprh,  # element stride of DXPR within a row
+    stride_dxb,  # row stride of DX
+    stride_dxh,  # element stride of DX within a row
+    stride_dgp,  # stride between consecutive programs' rows in DG_PART
+    B,  # number of rows; rows at or beyond B are masked out
+    H: tl.constexpr,  # number of elements per row
+    BLOCK_H: tl.constexpr,  # tile width; there is no loop over H, so this must be >= H
+    ROWS_PER_BLOCK: tl.constexpr,  # rows handled by one program
+):
+    pid = tl.program_id(0).to(tl.int64)  # widen: keeps row_offs and every row/pid * stride product in 64 bits
+    row_start = pid * ROWS_PER_BLOCK
+    row_offs = row_start + tl.arange(0, ROWS_PER_BLOCK)
+    row_mask = row_offs < B  # the last program may run past the final row
+    h_offs = tl.arange(0, BLOCK_H)
+    h_mask = h_offs < H
+
+    xpr_ptrs = XPR_ptr + row_offs[:, None] * stride_xprb + h_offs[None, :] * stride_xprh
+    dy_ptrs = DY_ptr + row_offs[:, None] * stride_dyb + h_offs[None, :] * stride_dyh
+    dxpr_ptrs = DXPR_ptr + row_offs[:, None] * stride_dxprb + h_offs[None, :] * stride_dxprh
+    dx_ptrs = DX_ptr + row_offs[:, None] * stride_dxb + h_offs[None, :] * stride_dxh
+    dgp_ptrs = DG_PART_ptr + pid * stride_dgp + h_offs
+    g_ptrs = G_ptr + h_offs
+
+    full_mask = row_mask[:, None] & h_mask[None, :]  # valid row AND valid column
+    xpr = tl.load(xpr_ptrs, mask=full_mask, other=0.0).to(tl.float32)
+    dy = tl.load(dy_ptrs, mask=full_mask, other=0.0).to(tl.float32)
+    dxpr = tl.load(dxpr_ptrs, mask=full_mask, other=0.0).to(tl.float32)
+    g = tl.load(g_ptrs, mask=h_mask, other=0.0).to(tl.float32)
+    rstd = tl.load(RSTD_ptr + row_offs, mask=row_mask, other=0.0).to(tl.float32)
+
+    x_hat = xpr * rstd[:, None]  # normalized input
+    dxhat = dy * g[None, :]  # gradient w.r.t. x_hat
+    m = tl.sum(dxhat * x_hat, axis=1) / H  # per-row mean of dxhat * x_hat
+    dx_norm = (dxhat - x_hat * m[:, None]) * rstd[:, None]
+    dx = dx_norm + dxpr  # plus the gradient that arrived through the (x + residual) output
+
+    tl.store(dx_ptrs, dx.to(DX_ptr.dtype.element_ty), mask=full_mask)  # cast back to DX's element type
+
+    dgp_block = (dy * x_hat) * row_mask[:, None].to(tl.float32)  # rows past B are zeroed (dy is already 0 there via the masked load)
+    dgp_row = tl.sum(dgp_block, axis=0)  # collapse this program's rows into one partial dgamma row
+    tl.store(dgp_ptrs, dgp_row, mask=h_mask)
diff --git a/tests/pytorch/ops/test_normalization.py b/tests/pytorch/ops/test_normalization.py index c89224e30..90aa0d373 100644 --- a/tests/pytorch/ops/test_normalization.py +++ b/tests/pytorch/ops/test_normalization.py @@ -8,14 +8,13 @@ import torch import torch.nn.functional as F -from primus_turbo.pytorch.ops.normalization import rmsnorm +from primus_turbo.pytorch.ops.normalization import rmsnorm, rmsnorm_residual from tests.pytorch.test_utils import get_tolerances -# @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("outer_shape", [(1,), (511,), (4096,), (8192,), (16384,)]) -@pytest.mark.parametrize("inner_shape", [33, 513, 4096, 5120, 7168, 8192]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("outer_shape", [(1,), (511,), (4096,), (8192,)]) +@pytest.mark.parametrize("inner_shape", [33, 128, 513, 4096, 5120, 7168, 8192]) def test_rmsnorm_ops(dtype, outer_shape, inner_shape): torch.manual_seed(1) device = "cuda:0" @@ -31,8 +30,6 @@ def test_rmsnorm_ops(dtype, outer_shape, inner_shape): y_ref = F.rms_norm(x_ref, [inner_shape], gamma_ref, eps) y = rmsnorm(x, gamma, eps) - # print(y_ref, y_ref.shape) - # print(y, y.shape) torch.testing.assert_close(y_ref, y, **get_tolerances(dtype)) # Backward @@ -40,11 +37,47 @@ def test_rmsnorm_ops(dtype, outer_shape, inner_shape): y.backward(grad_out) y_ref.backward(grad_out) - # print(x.grad) - # print(x_ref.grad) + torch.testing.assert_close(x.grad, x_ref.grad, **get_tolerances(dtype)) + torch.testing.assert_close(gamma.grad, gamma_ref.grad, **get_tolerances(dtype)) + + - # print(gamma.grad) - # print(gamma_ref.grad)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("outer_shape", [(1,), (511,), (4096,)])
+@pytest.mark.parametrize("inner_shape", [128, 4096, 8192])
+def test_rmsnorm_residual_ops(dtype, outer_shape, inner_shape):  # Compares rmsnorm_residual with F.rms_norm applied to x + residual.
+    torch.manual_seed(2)
+    device = "cuda:0"
+    eps = 1e-6
+
+    shape = outer_shape + (inner_shape,)  # leading dims followed by the normalized dim
+    x = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
+    residual = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
+    gamma = torch.randn(inner_shape, dtype=dtype, device=device, requires_grad=True)
+
+    x_ref = x.detach().clone().requires_grad_()  # separate leaf tensors so the reference path gets its own grads
+    r_ref = residual.detach().clone().requires_grad_()
+    gamma_ref = gamma.detach().clone().requires_grad_()
+
+    # Forward
+    h_ref = x_ref + r_ref  # reference pre-normalization sum
+    y_ref = F.rms_norm(h_ref, [inner_shape], gamma_ref, eps)
+
+    y, x_plus_r = rmsnorm_residual(x, residual, gamma, eps)  # op returns the normalized output and the summed input
+
+    torch.testing.assert_close(x_plus_r, h_ref.detach(), **get_tolerances(dtype))  # second output must equal x + residual
+    torch.testing.assert_close(y, y_ref, **get_tolerances(dtype))
+
+    # Backward — only flow gradient through y so dxpr = 0.
+    grad_out = torch.randn_like(y)
+    y.backward(grad_out)  # x_plus_r takes no part in this backward call
+    y_ref.backward(grad_out)
 
     torch.testing.assert_close(x.grad, x_ref.grad, **get_tolerances(dtype))
-    torch.testing.assert_close(gamma.grad, gamma_ref.grad, **get_tolerances(dtype))
+    torch.testing.assert_close(residual.grad, r_ref.grad, **get_tolerances(dtype))
+    # ``dgamma`` is a reduction over ``B`` rows; the residual variant doubles the
+    # input magnitude entering the sum, so for low-precision dtypes we use a
+    # slightly looser tolerance to absorb the extra bf16/fp16 reduction noise.
+    dg_tol = get_tolerances(dtype)  # default: the same tolerances used for the other checks
+    if dtype in (torch.float16, torch.bfloat16):
+        dg_tol = dict(rtol=3e-2, atol=3e-2)  # replaces the shared tolerances for the gamma gradient only
+    torch.testing.assert_close(gamma.grad, gamma_ref.grad, **dg_tol)
From e8eee833093b1e541f1ca5ba1cbac31d0455802f Mon Sep 17 00:00:00 2001 From: RuibinCheung Date: Wed, 20 May 2026 01:07:40 +0000 Subject: [PATCH 2/2] fix --- csrc/pytorch/bindings_pytorch.cpp | 12 ------------ csrc/pytorch/extensions.h | 14 -------------- 2 files changed, 26 deletions(-) diff --git a/csrc/pytorch/bindings_pytorch.cpp b/csrc/pytorch/bindings_pytorch.cpp index 349dfa06e..d485f12e6 100644 --- a/csrc/pytorch/bindings_pytorch.cpp +++ b/csrc/pytorch/bindings_pytorch.cpp @@ -58,10 +58,6 @@ TORCH_LIBRARY(primus_turbo_cpp_extension, m) { m.def("shuffle_scale(Tensor scale, int[] layout) -> Tensor"); m.def("shuffle_weight(Tensor weight, int[] layout) -> Tensor"); - // ********* RMSNorm ********* - m.def("rmsnorm_fwd(Tensor input, Tensor gamma, float eps) -> Tensor"); - m.def("rmsnorm_bwd(Tensor input, Tensor gamma, Tensor grad_out, float eps) -> Tensor[]"); - // ********* Grouped Gemm ********* m.def("ck_grouped_gemm(Tensor a, Tensor b, Tensor group_lens, Tensor group_offs, bool transA, " "bool transB, int?
num_cu=None) -> Tensor"); @@ -105,10 +101,6 @@ TORCH_LIBRARY_IMPL(primus_turbo_cpp_extension, CUDA, m) { m.impl("shuffle_scale", shuffle_scale_impl); m.impl("shuffle_weight", shuffle_weight_impl); - // ********* RMSNorm ********* - m.impl("rmsnorm_fwd", rmsnorm_fwd); - m.impl("rmsnorm_bwd", rmsnorm_bwd); - // ********* Grouped Gemm ********* m.impl("ck_grouped_gemm", ck_grouped_gemm); m.impl("ck_grouped_gemm_variable_k", ck_grouped_gemm_variable_k); @@ -143,10 +135,6 @@ TORCH_LIBRARY_IMPL(primus_turbo_cpp_extension, Meta, m) { m.impl("shuffle_scale", shuffle_scale_impl_meta); m.impl("shuffle_weight", shuffle_weight_impl_meta); - // ********* RMSNorm ********* - m.impl("rmsnorm_fwd", rmsnorm_fwd_meta); - m.impl("rmsnorm_bwd", rmsnorm_bwd_meta); - // ********* Grouped Gemm ********* m.impl("ck_grouped_gemm", ck_grouped_gemm_meta); m.impl("ck_grouped_gemm_variable_k", ck_grouped_gemm_variable_k_meta); diff --git a/csrc/pytorch/extensions.h b/csrc/pytorch/extensions.h index b7c09f85e..30520e6f6 100644 --- a/csrc/pytorch/extensions.h +++ b/csrc/pytorch/extensions.h @@ -153,20 +153,6 @@ at::Tensor turbo_gemm_fp8_meta(at::Tensor A, at::Tensor scaleA_inv, at::Tensor B at::Tensor scaleB_inv, const at::ScalarType out_dtype, bool transA, bool transB, bool transC, const std::string &granularity); -//================================================================== -// Normalization -//================================================================== - -at::Tensor rmsnorm_fwd(const at::Tensor &input, const at::Tensor &gamma, const double eps); - -at::Tensor rmsnorm_fwd_meta(const at::Tensor &input, const at::Tensor &gamma, const double eps); - -std::vector rmsnorm_bwd(const at::Tensor &input, const at::Tensor &gamma, - const at::Tensor &grad_output, const double eps); - -std::vector rmsnorm_bwd_meta(const at::Tensor &input, const at::Tensor &gamma, - const at::Tensor &grad_output, const double eps); - //================================================================== // 
Grouped GEMM //==================================================================