163 changes: 154 additions & 9 deletions atom/model_ops/layernorm.py
@@ -5,6 +5,8 @@

import aiter
import torch
import triton
import triton.language as tl
from aiter import (
    QuantType,
    layernorm2d_fwd,
@@ -288,6 +290,61 @@ def forward(
        return x, residual


# decode: one program per (token, head) row; head_dim is hard-coded to 128
@triton.jit
def _rmsnorm_gated_contiguous_128_kernel(
    x_ptr,
    z_ptr,
    weight_ptr,
    out_ptr,
    num_heads: tl.constexpr,
    eps: tl.constexpr,
):
    token_id = tl.program_id(0)
    head_id = tl.program_id(1)
    offsets = tl.arange(0, 128)
    row_offset = (token_id * num_heads + head_id) * 128

    x = tl.load(x_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
    z = tl.load(z_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
    weight = tl.load(weight_ptr + offsets, cache_modifier=".ca").to(tl.float32)

    variance = tl.sum(x * x, axis=0) * 0.0078125  # 1/128: mean over head_dim
    inv_rms = tl.rsqrt(variance + eps)
    gate = z * tl.sigmoid(z)  # SiLU gate
    out = x * inv_rms * weight * gate

    tl.store(out_ptr + row_offset + offsets, out)


# prefill: each program normalizes a tile of block_rows 128-wide rows at once
@triton.jit
def _rmsnorm_gated_contiguous_128_tiled_rows_kernel(
    x_ptr,
    z_ptr,
    weight_ptr,
    out_ptr,
    num_rows: tl.constexpr,
    eps: tl.constexpr,
    block_rows: tl.constexpr,
):
    row_offsets = tl.program_id(0) * block_rows + tl.arange(0, block_rows)
    dim_offsets = tl.arange(0, 128)
    mask_rows = row_offsets < num_rows
    offsets = row_offsets[:, None] * 128 + dim_offsets[None, :]

    x = tl.load(x_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
    z = tl.load(z_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
    weight = tl.load(weight_ptr + dim_offsets, cache_modifier=".ca").to(tl.float32)

    variance = tl.sum(x * x, axis=1) * 0.0078125  # 1/128: mean over head_dim
    inv_rms = tl.rsqrt(variance + eps)
    gate = z * tl.sigmoid(z)  # SiLU gate
    out = x * inv_rms[:, None] * weight[None, :] * gate

    tl.store(out_ptr + offsets, out, mask=mask_rows[:, None])
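
# Illustrative eager-mode reference for the two kernels above. A minimal
# sketch for documentation only (the helper name is ours): RMS-normalize each
# 128-wide row of x in fp32, scale by weight, and gate with SiLU(z). The
# kernels write the same values into an output flattened to (tokens, heads * 128).
def _rmsnorm_gated_ref_128(x, z, weight, eps):
    x32, z32 = x.float(), z.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * inv_rms * weight.float() * z32 * torch.sigmoid(z32)).to(x.dtype)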


class RMSNormGated(nn.Module):
"""RMS Normalization with optional gating.

@@ -360,6 +417,81 @@ def __init__(
    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward_triton(
        self, x: torch.Tensor, z: torch.Tensor
    ) -> tuple[torch.Tensor, None]:
        # Fast path requires gated, contiguous (tokens, heads, 128) input,
        # normalized per head before gating; anything else falls back.
        if (
            z is None
            or x.ndim != 3
            or self.group_size is not None
            or not self.norm_before_gate
            or x.shape[-1] != 128
            or not x.is_contiguous()
            or not z.is_contiguous()
        ):
            return self.forward_native(x, z)

        num_tokens, num_heads, head_dim = x.shape
        out = torch.empty(
            (num_tokens, num_heads * head_dim),
            dtype=x.dtype,
            device=x.device,
        )

        num_rows = num_tokens * num_heads
        if num_rows >= 65536:  # large prefill: tile rows to cut per-program overhead
            block_rows = 32
            _rmsnorm_gated_contiguous_128_tiled_rows_kernel[
                (triton.cdiv(num_rows, block_rows),)
            ](
                x,
                z,
                self.weight,
                out,
                num_rows,
                self.eps,
                block_rows,
                num_warps=4,
                num_stages=1,
            )
        else:  # decode-sized batches: one lightweight program per (token, head) row
            _rmsnorm_gated_contiguous_128_kernel[(num_tokens, num_heads)](
                x,
                z,
                self.weight,
                out,
                num_heads,
                self.eps,
                num_warps=1,
                num_stages=1,
            )
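
        # Dispatch example (hypothetical shapes): x of shape (8, 32, 128)
        # gives num_rows = 256 < 65536, so the decode kernel runs on an
        # (8, 32) grid; x of shape (4096, 32, 128) gives num_rows = 131072,
        # so the tiled kernel runs cdiv(131072, 32) = 4096 programs of
        # 32 rows each.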

        return (out, None)

    def forward_aiter_gated(
        self, x: torch.Tensor, z: torch.Tensor
    ) -> tuple[torch.Tensor, None]:
        try:
            from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_bf16
        except (ImportError, AttributeError):
            # aiter builds without the fused gated kernel use the Triton path.
            return self.forward_triton(x, z)

        # Same layout requirements as the Triton fast path.
        if (
            z is None
            or x.ndim != 3
            or self.group_size is not None
            or not self.norm_before_gate
            or x.shape[-1] != 128
            or not x.is_contiguous()
            or not z.is_contiguous()
        ):
            return self.forward_native(x, z)

        num_tokens, num_heads, head_dim = x.shape
        out = torch.empty(
            (num_tokens, num_heads * head_dim),
            dtype=x.dtype,
            device=x.device,
        )
        gated_rmsnorm_bf16(out, x, z, self.weight, self.eps)
        return (out, None)
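
    # Fallback chain for the gated path: forward() tries the aiter kernel
    # first, drops to the Triton kernels when gated_rmsnorm_bf16 is not
    # available, and both fast paths defer to forward_native() whenever the
    # layout checks above fail.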

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor
    ) -> tuple[torch.Tensor, None]:
@@ -479,7 +611,7 @@ def forward(
        if self.use_fused_fp8_quant:
            return self.forward_fused_fp8(x, z)

-        return self.forward_native(x, z)
        return self.forward_aiter_gated(x, z)


class GemmaRMSNorm(nn.Module):
@@ -542,16 +674,33 @@ def forward_native(
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)

    def _get_aiter_weight(self) -> torch.Tensor:
        """Cache weight + 1.0 for aiter rmsnorm (which uses x*w, not x*(1+w))."""
        if not hasattr(self, "_aiter_weight_cache") or self._aiter_weight_cache is None:
            self._aiter_weight_cache = self.weight.data + 1.0
        return self._aiter_weight_cache
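
    # Why the +1.0: Gemma's native path scales by (1 + weight), i.e.
    # y = x_norm * (1 + w), while aiter's rmsnorm2d kernels compute
    # y = x_norm * w'. Passing w' = w + 1.0 makes the two paths agree
    # exactly (e.g. a learned w of 0.25 becomes a multiplier of 1.25).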

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton
-        return gemma_rmsnorm_triton(
-            x, self.weight.data, self.variance_epsilon, residual
-        )
        from aiter import rmsnorm2d_fwd, rmsnorm2d_fwd_with_add

        weight = self._get_aiter_weight()
        ori_shape = x.shape
        x = x.view(-1, ori_shape[-1])  # aiter's 2-D kernels expect (rows, hidden)

        if residual is not None:
            residual = residual.view(-1, ori_shape[-1])
            out = torch.empty_like(x)
            residual_out = torch.empty_like(residual)
            rmsnorm2d_fwd_with_add(
                out, x, residual, residual_out, weight, self.variance_epsilon
            )
            return out.view(ori_shape), residual_out.view(ori_shape)

        return rmsnorm2d_fwd(x, weight, self.variance_epsilon).view(ori_shape)
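
    # Shape round-trip: an x of shape (B, S, H) is viewed as (B * S, H) for
    # the 2-D aiter kernels and restored on return. With a residual, the
    # fused kernel also produces residual_out (assuming aiter's usual
    # add-then-normalize semantics, residual_out = x + residual).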

    def _forward_fused_fp8(self, x, residual=None):
        from aiter.ops.fused_qk_rmsnorm_group_quant import fused_qk_rmsnorm_group_quant
@@ -603,10 +752,6 @@ def forward(
# ---------------------------------------------------------------------------
# Fused Q/K RMSNorm Triton kernel
# ---------------------------------------------------------------------------
-import triton  # noqa: E402
-import triton.language as tl  # noqa: E402


@triton.jit
def _fused_qk_norm_single_kernel(
    q_ptr,