From f3be04795140fa8dc89760175ad6a28e3b67b5f1 Mon Sep 17 00:00:00 2001
From: zovonoir
Date: Thu, 7 May 2026 10:59:18 +0800
Subject: [PATCH] perf: fused Triton kernels for Qwen3.5 RMSNorm and MRoPE
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Four performance changes for Qwen3.5 (GDN + full attention):

1. RMSNormGated: add a fused Triton kernel (separate decode and prefill
   paths) and aiter HIP kernel integration, with a fallback chain down to
   the native implementation
2. GemmaRMSNorm: replace torch.compile with direct aiter rmsnorm2d_fwd
   calls; cache (weight + 1.0) to avoid a per-call allocation
3. MRoPE: add a fused Triton kernel that applies multimodal RoPE (3D
   T/H/W positions) to Q and K in a single launch (specialized for
   head_size=256, rotary_dim=64, Neox-style)
4. qwen3_next: integrate the MRoPE fusion and replace the pre-allocated
   output pattern (output[:] = ...) with a direct return

Reference sketches for items 1-3 follow below.
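Item 1 in PyTorch terms: a minimal sketch of the math both fused
RMSNormGated kernels implement (norm_before_gate=True, head_dim=128;
the function name is illustrative and not part of the diff):

    import torch

    def rmsnorm_gated_ref(x, z, weight, eps):
        # x, z: [num_tokens, num_heads, 128]; weight: [128]
        xf, zf = x.float(), z.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
        gate = zf * torch.sigmoid(zf)  # SiLU(z)
        out = xf * inv_rms * weight.float() * gate
        # the kernels write a flattened [num_tokens, num_heads * 128] output
        return out.to(x.dtype).reshape(x.shape[0], -1)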
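Item 2 hinges on a weight-convention mismatch: Gemma-style RMSNorm scales
the normalized input by (1 + weight), while aiter's rmsnorm2d_fwd scales
by weight alone, so the patch computes the shifted weight once and caches
it instead of reallocating per call (sketch):

    # Gemma convention:  y = rms_norm(x) * (1.0 + weight)
    # aiter convention:  y = rms_norm(x) * weight
    # => hand aiter a shifted weight, computed once (see _get_aiter_weight)
    aiter_weight = weight + 1.0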
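Item 3's kernels pick one of the three position streams per rotary
frequency: with the interleaved layout, frequency index i uses the H
position when i % 3 == 1 and i < 3 * section_h, the W position when
i % 3 == 2 and i < 3 * section_w, and the T position otherwise. This
mirrors use_h/use_w in the kernels (illustrative sketch, not applied code):

    def pos_for_freq(i, section_h, section_w, pos_t, pos_h, pos_w):
        if i % 3 == 1 and i < 3 * section_h:
            return pos_h  # height stream
        if i % 3 == 2 and i < 3 * section_w:
            return pos_w  # width stream
        return pos_t      # temporal/text stream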
Benchmark (MI308X, Qwen3.5-27B, TP2, ISL=60381, OSL=132):

| Version                 | TPOT(ms) | TTFT(ms) | Tput/GPU |
|-------------------------|----------|----------|----------|
| ATOM baseline           | 18.7     | 15266    | 1574     |
| + RMSNormGated Triton   | 16.7     | 15247    | 1597     |
| + this patch            | 15.6     | 15300    | 1605     |

TPOT reduced by 17% end-to-end (18.7 → 15.6 ms).

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 atom/model_ops/layernorm.py    | 165 ++++++++++++++++--
 atom/model_ops/triton_mrope.py | 297 +++++++++++++++++++++++++++++++++
 atom/models/qwen3_next.py      |  38 +++--
 3 files changed, 472 insertions(+), 28 deletions(-)
 create mode 100644 atom/model_ops/triton_mrope.py

diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py
index 4e8aeeaf8..c29e2f065 100644
--- a/atom/model_ops/layernorm.py
+++ b/atom/model_ops/layernorm.py
@@ -5,6 +5,8 @@
 import aiter
 import torch
+import triton
+import triton.language as tl
 from aiter import (
     QuantType,
     layernorm2d_fwd,
@@ -288,6 +290,61 @@ def forward(
         return x, residual
 
 
+# decode
+@triton.jit
+def _rmsnorm_gated_contiguous_128_kernel(
+    x_ptr,
+    z_ptr,
+    weight_ptr,
+    out_ptr,
+    num_heads: tl.constexpr,
+    eps: tl.constexpr,
+):
+    token_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offsets = tl.arange(0, 128)
+    row_offset = (token_id * num_heads + head_id) * 128
+
+    x = tl.load(x_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
+    z = tl.load(z_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
+    weight = tl.load(weight_ptr + offsets, cache_modifier=".ca").to(tl.float32)
+
+    variance = tl.sum(x * x, axis=0) * 0.0078125  # 1/128
+    inv_rms = tl.rsqrt(variance + eps)
+    gate = z * tl.sigmoid(z)
+    out = x * inv_rms * weight * gate
+
+    tl.store(out_ptr + row_offset + offsets, out)
+
+
+# prefill
+@triton.jit
+def _rmsnorm_gated_contiguous_128_tiled_rows_kernel(
+    x_ptr,
+    z_ptr,
+    weight_ptr,
+    out_ptr,
+    num_rows: tl.constexpr,
+    eps: tl.constexpr,
+    block_rows: tl.constexpr,
+):
+    row_offsets = tl.program_id(0) * block_rows + tl.arange(0, block_rows)
+    dim_offsets = tl.arange(0, 128)
+    mask_rows = row_offsets < num_rows
+    offsets = row_offsets[:, None] * 128 + dim_offsets[None, :]
+
+    x = tl.load(x_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
+    z = tl.load(z_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
+    weight = tl.load(weight_ptr + dim_offsets, cache_modifier=".ca").to(tl.float32)
+
+    variance = tl.sum(x * x, axis=1) * 0.0078125  # 1/128
+    inv_rms = tl.rsqrt(variance + eps)
+    gate = z * tl.sigmoid(z)
+    out = x * inv_rms[:, None] * weight[None, :] * gate
+
+    tl.store(out_ptr + offsets, out, mask=mask_rows[:, None])
+
+
 class RMSNormGated(nn.Module):
     """RMS Normalization with optional gating.
 
@@ -360,6 +417,81 @@ def __init__(
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)
 
+    def forward_triton(self, x: torch.Tensor, z: torch.Tensor):
+        if (
+            z is None
+            or x.ndim != 3
+            or self.group_size is not None
+            or not self.norm_before_gate
+            or x.shape[-1] != 128
+            or not x.is_contiguous()
+            or not z.is_contiguous()
+        ):
+            return self.forward_native(x, z)
+
+        num_tokens, num_heads, head_dim = x.shape
+        out = torch.empty(
+            (num_tokens, num_heads * head_dim),
+            dtype=x.dtype,
+            device=x.device,
+        )
+
+        num_rows = num_tokens * num_heads
+        if num_rows >= 65536:
+            block_rows = 32
+            _rmsnorm_gated_contiguous_128_tiled_rows_kernel[
+                (triton.cdiv(num_rows, block_rows),)
+            ](
+                x,
+                z,
+                self.weight,
+                out,
+                num_rows,
+                self.eps,
+                block_rows,
+                num_warps=4,
+                num_stages=1,
+            )
+        else:
+            _rmsnorm_gated_contiguous_128_kernel[(num_tokens, num_heads)](
+                x,
+                z,
+                self.weight,
+                out,
+                num_heads,
+                self.eps,
+                num_warps=1,
+                num_stages=1,
+            )
+
+        return (out, None)
+
+    def forward_aiter_gated(self, x: torch.Tensor, z: torch.Tensor):
+        try:
+            from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_bf16
+        except (ImportError, AttributeError):
+            return self.forward_triton(x, z)
+
+        if (
+            z is None
+            or x.ndim != 3
+            or self.group_size is not None
+            or not self.norm_before_gate
+            or x.shape[-1] != 128
+            or not x.is_contiguous()
+            or not z.is_contiguous()
+        ):
+            return self.forward_native(x, z)
+
+        num_tokens, num_heads, head_dim = x.shape
+        out = torch.empty(
+            (num_tokens, num_heads * head_dim),
+            dtype=x.dtype,
+            device=x.device,
+        )
+        gated_rmsnorm_bf16(out, x, z, self.weight, self.eps)
+        return (out, None)
+
     def forward_native(
         self, x: torch.Tensor, z: torch.Tensor
     ) -> tuple[torch.Tensor, None]:
@@ -479,7 +611,7 @@ def forward(
         if self.use_fused_fp8_quant:
             return self.forward_fused_fp8(x, z)
 
-        return self.forward_native(x, z)
+        return self.forward_aiter_gated(x, z)
 
 
 class GemmaRMSNorm(nn.Module):
@@ -542,18 +674,33 @@ def forward_native(
         """PyTorch-native implementation equivalent to forward()."""
         return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
 
+    def _get_aiter_weight(self) -> torch.Tensor:
+        """Cache weight + 1.0 for aiter rmsnorm (which uses x*w, not x*(1+w))."""
+        if not hasattr(self, "_aiter_weight_cache") or self._aiter_weight_cache is None:
+            self._aiter_weight_cache = self.weight.data + 1.0
+        return self._aiter_weight_cache
+
     def forward_cuda(
         self,
         x: torch.Tensor,
        residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        if torch.compiler.is_compiling():
-            return self.forward_native(x, residual)
+        from aiter import rmsnorm2d_fwd, rmsnorm2d_fwd_with_add
 
-        if not getattr(self, "_is_compiled", False):
-            self.forward_static = torch.compile(self.forward_static)  # type: ignore
-            self._is_compiled = True
-        return self.forward_native(x, residual)
+        weight = self._get_aiter_weight()
+        ori_shape = x.shape
+        x = x.view(-1, ori_shape[-1])
+
+        if residual is not None:
+            residual = residual.view(-1, ori_shape[-1])
+            out = torch.empty_like(x)
+            residual_out = torch.empty_like(residual)
+            rmsnorm2d_fwd_with_add(
+                out, x, residual, residual_out, weight, self.variance_epsilon
+            )
+            return out.view(ori_shape), residual_out.view(ori_shape)
+
+        return rmsnorm2d_fwd(x, weight, self.variance_epsilon).view(ori_shape)
 
     def _forward_fused_fp8(self, x, residual=None):
         from aiter.ops.fused_qk_rmsnorm_group_quant import fused_qk_rmsnorm_group_quant
@@ -605,10 +752,6 @@ def forward(
 # ---------------------------------------------------------------------------
 # Fused Q/K RMSNorm Triton kernel
 # ---------------------------------------------------------------------------
-import triton  # noqa: E402
-import triton.language as tl  # noqa: E402
-
-
 @triton.jit
 def _fused_qk_norm_single_kernel(
     q_ptr,
diff --git a/atom/model_ops/triton_mrope.py b/atom/model_ops/triton_mrope.py
new file mode 100644
index 000000000..9ee245223
--- /dev/null
+++ b/atom/model_ops/triton_mrope.py
@@ -0,0 +1,297 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+"""Triton kernels for Qwen3.5 MRoPE.
+
+This module currently specializes the hot MRoPE path used by Qwen3.5:
+positions are 3D T/H/W ids, RoPE is Neox-style, head_size=256, and
+rotary_dim=64. Unsupported shapes return ``None`` so callers can fall back to
+the generic rotary embedding implementation.
+"""
+
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch import nn
+
+
+@triton.jit
+def _mrope_qk_kernel(
+    q_ptr,
+    k_ptr,
+    q_out_ptr,
+    k_out_ptr,
+    positions_ptr,
+    cos_ptr,
+    sin_ptr,
+    q_stride_t: tl.constexpr,
+    k_stride_t: tl.constexpr,
+    q_out_stride_t: tl.constexpr,
+    k_out_stride_t: tl.constexpr,
+    pos_stride_row: tl.constexpr,
+    cos_stride_pos: tl.constexpr,
+    sin_stride_pos: tl.constexpr,
+    num_q_heads: tl.constexpr,
+    num_k_heads: tl.constexpr,
+    head_size: tl.constexpr,
+    rotary_dim: tl.constexpr,
+    rotary_half: tl.constexpr,
+    section_h: tl.constexpr,
+    section_w: tl.constexpr,
+    block_d: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    total_heads = num_q_heads + num_k_heads
+    token_id = pid // total_heads
+    head_id = pid - token_id * total_heads
+    d = tl.arange(0, block_d)
+    mask = d < head_size
+
+    is_q = head_id < num_q_heads
+    local_head = tl.where(is_q, head_id, head_id - num_q_heads)
+
+    q_base_in = token_id * q_stride_t + local_head * head_size
+    k_base_in = token_id * k_stride_t + local_head * head_size
+    q_base_out = token_id * q_out_stride_t + local_head * head_size
+    k_base_out = token_id * k_out_stride_t + local_head * head_size
+
+    x_q = tl.load(q_ptr + q_base_in + d, mask=mask & is_q, other=0.0).to(tl.float32)
+    x_k = tl.load(k_ptr + k_base_in + d, mask=mask & ~is_q, other=0.0).to(tl.float32)
+    x = x_q + x_k
+
+    rot_mask = d < rotary_dim
+    first_half = d < rotary_half
+    freq_idx = tl.where(first_half, d, d - rotary_half)
+    pair_d = tl.where(
+        first_half,
+        d + rotary_half,
+        tl.where(d < rotary_dim, d - rotary_half, d),
+    )
+    pair_q = tl.load(q_ptr + q_base_in + pair_d, mask=mask & is_q, other=0.0).to(
+        tl.float32
+    )
+    pair_k = tl.load(k_ptr + k_base_in + pair_d, mask=mask & ~is_q, other=0.0).to(
+        tl.float32
+    )
+    pair = pair_q + pair_k
+
+    pos_t = tl.load(positions_ptr + token_id)
+    pos_h = tl.load(positions_ptr + pos_stride_row + token_id)
+    pos_w = tl.load(positions_ptr + 2 * pos_stride_row + token_id)
+
+    use_h = ((freq_idx % 3) == 1) & (freq_idx < section_h * 3)
+    use_w = ((freq_idx % 3) == 2) & (freq_idx < section_w * 3)
+    pos = tl.where(use_h, pos_h, tl.where(use_w, pos_w, pos_t))
+
+    cos = tl.load(
+        cos_ptr + pos * cos_stride_pos + freq_idx,
+        mask=rot_mask,
+        other=1.0,
+    ).to(tl.float32)
+    sin = tl.load(
+        sin_ptr + pos * sin_stride_pos + freq_idx,
+        mask=rot_mask,
+        other=0.0,
+    ).to(tl.float32)
+
+    rotated = tl.where(first_half, -pair, pair)
+    out = tl.where(rot_mask, x * cos + rotated * sin, x)
+
+    tl.store(q_out_ptr + q_base_out + d, out, mask=mask & is_q)
+    tl.store(k_out_ptr + k_base_out + d, out, mask=mask & ~is_q)
+
+
+@triton.jit
+def _mrope_qk_tiled_kernel(
+    q_ptr,
+    k_ptr,
+    q_out_ptr,
+    k_out_ptr,
+    positions_ptr,
+    cos_ptr,
+    sin_ptr,
+    q_stride_t: tl.constexpr,
+    k_stride_t: tl.constexpr,
+    q_out_stride_t: tl.constexpr,
+    k_out_stride_t: tl.constexpr,
+    pos_stride_row: tl.constexpr,
+    cos_stride_pos: tl.constexpr,
+    sin_stride_pos: tl.constexpr,
+    num_tokens: tl.constexpr,
+    num_q_heads: tl.constexpr,
+    num_k_heads: tl.constexpr,
+    head_size: tl.constexpr,
+    rotary_dim: tl.constexpr,
+    rotary_half: tl.constexpr,
+    section_h: tl.constexpr,
+    section_w: tl.constexpr,
+    block_t: tl.constexpr,
+    block_d: tl.constexpr,
+):
+    token_block = tl.program_id(0)
+    head_id = tl.program_id(1)
+    rows = token_block * block_t + tl.arange(0, block_t)
+    d = tl.arange(0, block_d)
+    row_mask = rows < num_tokens
+    d_mask = d < head_size
+
+    is_q = head_id < num_q_heads
+    local_head = tl.where(is_q, head_id, head_id - num_q_heads)
+
+    offsets_q = rows[:, None] * q_stride_t + local_head * head_size + d[None, :]
+    offsets_k = rows[:, None] * k_stride_t + local_head * head_size + d[None, :]
+    mask = row_mask[:, None] & d_mask[None, :]
+
+    x_q = tl.load(q_ptr + offsets_q, mask=mask & is_q, other=0.0).to(tl.float32)
+    x_k = tl.load(k_ptr + offsets_k, mask=mask & ~is_q, other=0.0).to(tl.float32)
+    x = x_q + x_k
+
+    rot_mask = d < rotary_dim
+    first_half = d < rotary_half
+    freq_idx = tl.where(first_half, d, d - rotary_half)
+    pair_d = tl.where(
+        first_half,
+        d + rotary_half,
+        tl.where(d < rotary_dim, d - rotary_half, d),
+    )
+    pair_offsets_q = (
+        rows[:, None] * q_stride_t + local_head * head_size + pair_d[None, :]
+    )
+    pair_offsets_k = (
+        rows[:, None] * k_stride_t + local_head * head_size + pair_d[None, :]
+    )
+    pair_q = tl.load(q_ptr + pair_offsets_q, mask=mask & is_q, other=0.0).to(tl.float32)
+    pair_k = tl.load(k_ptr + pair_offsets_k, mask=mask & ~is_q, other=0.0).to(
+        tl.float32
+    )
+    pair = pair_q + pair_k
+
+    pos_t = tl.load(positions_ptr + rows, mask=row_mask, other=0)
+    pos_h = tl.load(positions_ptr + pos_stride_row + rows, mask=row_mask, other=0)
+    pos_w = tl.load(positions_ptr + 2 * pos_stride_row + rows, mask=row_mask, other=0)
+
+    use_h = ((freq_idx % 3) == 1) & (freq_idx < section_h * 3)
+    use_w = ((freq_idx % 3) == 2) & (freq_idx < section_w * 3)
+    pos = tl.where(
+        use_h[None, :],
+        pos_h[:, None],
+        tl.where(use_w[None, :], pos_w[:, None], pos_t[:, None]),
+    )
+
+    cos = tl.load(
+        cos_ptr + pos * cos_stride_pos + freq_idx[None, :],
+        mask=row_mask[:, None] & rot_mask[None, :],
+        other=1.0,
+    ).to(tl.float32)
+    sin = tl.load(
+        sin_ptr + pos * sin_stride_pos + freq_idx[None, :],
+        mask=row_mask[:, None] & rot_mask[None, :],
+        other=0.0,
+    ).to(tl.float32)
+
+    rotated = tl.where(first_half[None, :], -pair, pair)
+    out = tl.where(rot_mask[None, :], x * cos + rotated * sin, x)
+
+    out_offsets_q = rows[:, None] * q_out_stride_t + local_head * head_size + d[None, :]
+    out_offsets_k = rows[:, None] * k_out_stride_t + local_head * head_size + d[None, :]
+    tl.store(q_out_ptr + out_offsets_q, out, mask=mask & is_q)
+    tl.store(k_out_ptr + out_offsets_k, out, mask=mask & ~is_q)
+
+
+def try_mrope_qk_fused(
+    rotary_emb: nn.Module,
+    positions: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    num_q_heads: int,
+    num_k_heads: int,
+    head_size: int,
+) -> tuple[torch.Tensor, torch.Tensor] | None:
+    """Try the specialized Qwen3.5 MRoPE Triton path.
+
+    Returns ``None`` for unsupported shapes so callers can fall back to the
+    generic rotary embedding module.
+    """
+    mrope_section = getattr(rotary_emb, "mrope_section", None)
+    if (
+        positions.ndim != 2
+        or mrope_section is None
+        or not getattr(rotary_emb, "mrope_interleaved", False)
+        or head_size != 256
+        or getattr(rotary_emb, "rotary_dim", None) != 64
+        or not getattr(rotary_emb, "is_neox_style", False)
+        or q.ndim != 2
+        or k.ndim != 2
+        or q.shape[1] != num_q_heads * head_size
+        or k.shape[1] != num_k_heads * head_size
+    ):
+        return None
+
+    q_out = torch.empty_like(q)
+    k_out = torch.empty_like(k)
+    num_tokens = positions.shape[1]
+    block_d = triton.next_power_of_2(head_size)
+    cos_cache = rotary_emb.cos_cache
+    sin_cache = rotary_emb.sin_cache
+
+    if num_tokens >= 128:
+        block_t = 16
+        _mrope_qk_tiled_kernel[
+            (triton.cdiv(num_tokens, block_t), num_q_heads + num_k_heads)
+        ](
+            q,
+            k,
+            q_out,
+            k_out,
+            positions,
+            cos_cache,
+            sin_cache,
+            q.stride(0),
+            k.stride(0),
+            q_out.stride(0),
+            k_out.stride(0),
+            positions.stride(0),
+            cos_cache.stride(0),
+            sin_cache.stride(0),
+            num_tokens,
+            num_q_heads,
+            num_k_heads,
+            head_size,
+            64,  # rotary_dim
+            32,  # rotary_half
+            mrope_section[1],  # section_h
+            mrope_section[2],  # section_w
+            block_t,
+            block_d,
+            num_warps=8,
+            num_stages=1,
+        )
+    else:
+        _mrope_qk_kernel[(num_tokens * (num_q_heads + num_k_heads),)](
+            q,
+            k,
+            q_out,
+            k_out,
+            positions,
+            cos_cache,
+            sin_cache,
+            q.stride(0),
+            k.stride(0),
+            q_out.stride(0),
+            k_out.stride(0),
+            positions.stride(0),
+            cos_cache.stride(0),
+            sin_cache.stride(0),
+            num_q_heads,
+            num_k_heads,
+            head_size,
+            64,  # rotary_dim
+            32,  # rotary_half
+            mrope_section[1],  # section_h
+            mrope_section[2],  # section_w
+            block_d,
+            num_warps=8,
+            num_stages=1,
+        )
+    return q_out, k_out
diff --git a/atom/models/qwen3_next.py b/atom/models/qwen3_next.py
index 8a528e75f..dbe06a487 100644
--- a/atom/models/qwen3_next.py
+++ b/atom/models/qwen3_next.py
@@ -34,6 +34,7 @@
     fused_split_chunk_qwen_next_qkvzba,
 )
 from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled
+from atom.model_ops.triton_mrope import try_mrope_qk_fused
 from atom.model_ops.utils import atom_parameter
 from atom.models.utils import (
     IntermediateTensors,
@@ -405,7 +406,6 @@ def __init__(
     def forward(
         self,
         positions: torch.Tensor,
-        output: torch.Tensor,
         hidden_states: torch.Tensor,
         x_scale=None,
     ) -> torch.Tensor:
@@ -431,7 +431,19 @@
             )
         else:
             q, k = self.qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        fused_qk = try_mrope_qk_fused(
+            self.rotary_emb,
+            positions,
+            q,
+            k,
+            self.num_heads,
+            self.num_kv_heads,
+            self.head_dim,
+        )
+        if fused_qk is None:
+            q, k = self.rotary_emb(positions, q, k)
+        else:
+            q, k = fused_qk
 
         attn_output = self.attn(q, k, v)
         if self.use_fused_sigmoid_mul_quant:
@@ -440,13 +452,13 @@
             )
             attn_output, attn_scale = fused_sigmoid_mul_fp8_quant(attn_output, gate)
-            output[:] = self.o_proj(attn_output, x_scale=attn_scale)
+            output = self.o_proj(attn_output, x_scale=attn_scale)
         elif self.attn_output_gate:
             gate = torch.sigmoid(gate)
             attn_output = attn_output * gate
-            output[:] = self.o_proj(attn_output)
+            output = self.o_proj(attn_output)
         else:
-            output[:] = self.o_proj(attn_output)
+            output = self.o_proj(attn_output)
 
         return output
 
@@ -714,7 +726,6 @@ def rearrange_mixed_qkv(self, mixed_qkv):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        output: torch.Tensor,
         x_fp8=None,
         x_scale=None,
    ):
@@ -724,8 +735,6 @@ def forward(
         2. Core attention (custom op)
         3. Output projection
         """
-        num_tokens = hidden_states.size(0)
-
         # ============================================================
         # Part 1: Input Projection
         # ============================================================
@@ -771,7 +780,8 @@ def forward(
         # ============================================================
         core_attn_out, maybe_scale = self.norm(core_attn_out, z)
 
-        output[:num_tokens] = self.out_proj(core_attn_out, x_scale=maybe_scale)
+        output = self.out_proj(core_attn_out, x_scale=maybe_scale)
+        return output
 
 
 if is_vllm():
@@ -963,28 +973,22 @@ def forward(
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        self_attention_output = torch.empty(
-            hidden_states.shape, dtype=residual.dtype, device=hidden_states.device
-        )
         if self.layer_type == "linear_attention":
-            self.linear_attn(
+            hidden_states = self.linear_attn(
                 hidden_states=(
                     hidden_bf16 if hidden_bf16 is not None else hidden_states
                 ),
-                output=self_attention_output,
                 x_fp8=hidden_states if x_scale is not None else None,
                 x_scale=x_scale,
             )
         elif self.layer_type == "full_attention":
-            self.self_attn(
+            hidden_states = self.self_attn(
                 hidden_states=hidden_states,
-                output=self_attention_output,
                 positions=positions,
                 x_scale=x_scale,
             )
         else:
             raise ValueError("Invalid layer_type")
-        hidden_states = self_attention_output
 
         if self.layer_scale:
             if len(hidden_states.shape) == 2: