
perf: fused Triton kernels for Qwen3.5 RMSNorm and MRoPE#708

Open
zovonoir wants to merge 3 commits into ROCm:main from zovonoir:perf/triton-fused-kernels-qwen35

Conversation

@zovonoir
Contributor

@zovonoir zovonoir commented May 7, 2026

Summary

  • Add fused Triton kernels for RMSNormGated (decode + prefill paths) with aiter HIP kernel integration and fallback chain
  • Replace torch.compile in GemmaRMSNorm with direct aiter rmsnorm2d_fwd calls, caching weight + 1.0 to avoid per-call allocation
  • Add fused Triton MRoPE kernel that applies multi-resolution RoPE to Q and K in a single launch (specialized for Qwen3.5: head_size=256, rotary_dim=64, Neox-style)
  • Integrate MRoPE fusion in qwen3_next.py, remove pre-allocated output pattern (output[:] = ...) in favor of direct return
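
For reference, the transform the fused kernel implements can be written in plain PyTorch. This is only a sketch: the function name, the `(3, T)` position layout, and the section-stitching convention follow the common multi-resolution RoPE formulation and are assumptions, not the PR's exact code.

```python
import torch

def mrope_ref(positions, q, k, cos_cache, sin_cache, mrope_section,
              num_q_heads, num_k_heads, head_size, rotary_dim):
    # positions: (3, T) — temporal/height/width position ids per token.
    # cos_cache / sin_cache: (max_pos, rotary_dim // 2).
    cos = cos_cache[positions]          # (3, T, rotary_dim // 2)
    sin = sin_cache[positions]
    # Take section i of the rotary dims from position axis i and stitch them.
    cos = torch.cat([m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
    sin = torch.cat([m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)

    def rotate(x, num_heads):
        # Neox-style rotation on the first rotary_dim dims of each head.
        t = x.shape[0]
        x = x.view(t, num_heads, head_size)
        rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
        x1, x2 = rot.chunk(2, dim=-1)
        c, s = cos.unsqueeze(1), sin.unsqueeze(1)   # broadcast over heads
        out = torch.cat([x1 * c - x2 * s, x2 * c + x1 * s, rest], dim=-1)
        return out.view(t, num_heads * head_size)

    return rotate(q, num_q_heads), rotate(k, num_k_heads)
```

The fused Triton kernel does the same math for Q and K in one launch instead of two (or more) separate kernel calls.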

Benchmark

MI308X, Qwen3.5-27B, ATOM SGLang plugin mode, TP2, ISL=60381, OSL=132:

| Version               | TPOT (ms) | TTFT (ms) | Tput/GPU (tok/s) |
|-----------------------|-----------|-----------|------------------|
| ATOM baseline         | 18.7      | 15266     | 1574             |
| + RMSNormGated Triton | 16.7      | 15247     | 1597             |
| + this patch          | 15.6      | 15300     | 1605             |

TPOT reduced by 17% end-to-end (18.7 → 15.6 ms).

Files Changed

  • atom/model_ops/layernorm.py — Triton RMSNormGated kernels + aiter HIP integration + GemmaRMSNorm optimization
  • atom/model_ops/triton_mrope.py — new file: fused MRoPE Triton kernel for Qwen3.5
  • atom/models/qwen3_next.py — integrate MRoPE fusion + remove pre-allocated output pattern
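
The "fallback chain" mentioned for layernorm.py can be sketched as a dispatch that prefers the aiter HIP kernel and degrades to an eager PyTorch path. This is a minimal sketch: the Triton middle tier is elided, and the silu-gated form and helper names are assumptions rather than the PR's actual code.

```python
import torch
import torch.nn.functional as F

def _eager_rmsnorm(x, weight, eps):
    # Reference path: x / rms(x) * weight, computed in fp32 for stability.
    var = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

def rmsnorm_gated(x, gate, weight, eps=1e-6):
    # Fallback chain: aiter HIP kernel -> (Triton kernel, omitted) -> eager.
    try:
        from aiter import rmsnorm2d_fwd        # ROCm aiter fast path
        normed = rmsnorm2d_fwd(x, weight, eps)
    except ImportError:
        normed = _eager_rmsnorm(x, weight, eps)
    # Gated variant: scale the normalized activations by silu(gate)
    # (assumed gating form).
    return normed * F.silu(gate)
```

On a machine without aiter the chain silently lands on the eager path, which keeps CI and non-ROCm environments working.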

Test plan

  • End-to-end benchmark on MI308X with Qwen3.5-27B
  • CI accuracy tests

🤖 Generated with Claude Code

Four performance optimizations for Qwen3.5 (GDN + full attention):

1. RMSNormGated: add Triton fused kernel (decode + prefill paths) and
   aiter HIP kernel integration with fallback chain
2. GemmaRMSNorm: replace torch.compile with direct aiter rmsnorm2d_fwd
   calls, cache (weight + 1.0) to avoid per-call allocation
3. MRoPE: add fused Triton kernel that applies multi-resolution RoPE
   to Q and K in a single launch (specialized for head_size=256,
   rotary_dim=64, Neox-style)
4. qwen3_next: integrate MRoPE fusion, remove pre-allocated output
   pattern (output[:] = ...) in favor of direct return

Benchmark (MI308X, Qwen3.5-27B, TP2, ISL=60381, OSL=132):

| Version                 | TPOT(ms) | TTFT(ms) | Tput/GPU |
|-------------------------|----------|----------|----------|
| ATOM baseline           | 18.7     | 15266    | 1574     |
| + RMSNormGated Triton   | 16.7     | 15247    | 1597     |
| + this patch            | 15.6     | 15300    | 1605     |

TPOT reduced by 17% end-to-end (18.7 → 15.6 ms).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 7, 2026 02:59
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces fused Triton kernels to reduce launch overhead and improve end-to-end latency for Qwen3.5, including a fused Q/K MRoPE path and optimized RMSNorm variants.

Changes:

  • Integrates a fused Triton MRoPE kernel for applying Qwen3.5 multi-resolution RoPE to Q and K in a single launch, with a fallback to the existing rotary embedding implementation.
  • Adds Triton (and optional aiter HIP) fused kernels for RMSNormGated and replaces torch.compile usage in GemmaRMSNorm with direct aiter rmsnorm2d_fwd calls (including a cached weight + 1.0).
  • Refactors qwen3_next.py attention/linear-attention codepaths to return tensors directly (removing the preallocated output-buffer write pattern).
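
The output-buffer refactor can be illustrated with a minimal sketch, where the arithmetic is a stand-in for the real attention computation:

```python
import torch

# Before: the caller preallocates a buffer and the op writes into it,
# costing an extra full-tensor copy kernel.
def attn_forward_old(x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    output[:] = x * 2.0        # stand-in for the real attention math
    return output

# After: the op allocates and returns its result directly.
def attn_forward_new(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0
```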

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File Description
atom/models/qwen3_next.py Wires in fused MRoPE Q/K path and refactors attention APIs to return outputs instead of writing into preallocated buffers.
atom/model_ops/triton_mrope.py Adds a new fused Triton kernel specialized for Qwen3.5 MRoPE (Q/K fused).
atom/model_ops/layernorm.py Adds/uses fused Triton + aiter RMSNormGated paths and switches GemmaRMSNorm to aiter rmsnorm2d kernels with cached weight adjustment.


Comment thread atom/models/qwen3_next.py
Comment on lines +977 to 981
hidden_states = self.linear_attn(
hidden_states=(
hidden_bf16 if hidden_bf16 is not None else hidden_states
),
output=self_attention_output,
x_fp8=hidden_states if x_scale is not None else None,
Comment thread atom/models/qwen3_next.py
Comment on lines 726 to 731
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
x_fp8=None,
x_scale=None,
):
Comment on lines +216 to +220
mrope_section = getattr(rotary_emb, "mrope_section", None)
if (
positions.ndim != 2
or mrope_section is None
or not getattr(rotary_emb, "mrope_interleaved", False)
Comment on lines +226 to +232
or q.shape[1] != num_q_heads * head_size
or k.shape[1] != num_k_heads * head_size
):
return None

q_out = torch.empty_like(q)
k_out = torch.empty_like(k)
Comment on lines +677 to +681
def _get_aiter_weight(self) -> torch.Tensor:
"""Cache weight + 1.0 for aiter rmsnorm (which uses x*w, not x*(1+w))."""
if not hasattr(self, "_aiter_weight_cache") or self._aiter_weight_cache is None:
self._aiter_weight_cache = self.weight.data + 1.0
return self._aiter_weight_cache
Comment on lines +692 to +703
x = x.view(-1, ori_shape[-1])

if residual is not None:
residual = residual.view(-1, ori_shape[-1])
out = torch.empty_like(x)
residual_out = torch.empty_like(residual)
rmsnorm2d_fwd_with_add(
out, x, residual, residual_out, weight, self.variance_epsilon
)
return out.view(ori_shape), residual_out.view(ori_shape)

return rmsnorm2d_fwd(x, weight, self.variance_epsilon).view(ori_shape)
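
The weight-caching trick behind `_get_aiter_weight` rests on a simple identity: Gemma scales the normalized activations by `(1 + w)` while aiter's rmsnorm scales by `w`, so passing a cached `w + 1.0` makes the `x*w` kernel reproduce the Gemma result with no per-call add. A CPU sketch, where the `rmsnorm` helper is an eager stand-in for `rmsnorm2d_fwd`:

```python
import torch

def rmsnorm(x, w, eps=1e-6):
    # Plain rmsnorm: x / rms(x) * w — the x*w convention aiter uses.
    var = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * w

x = torch.randn(4, 16)
w = torch.randn(16) * 0.02              # Gemma-style weight, learned near zero

# Gemma convention: scale by (1 + w).
gemma_out = rmsnorm(x, torch.ones(16)) * (1.0 + w)

# Equivalent: hand the x*w kernel a cached (w + 1.0) — one add at init,
# none per forward call.
cached = w + 1.0
aiter_style_out = rmsnorm(x, cached)

assert torch.allclose(gemma_out, aiter_style_out, atol=1e-5)
```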
@ganyi1996ppo
Contributor

We seem to already have this, if I remember correctly...
Try enabling the env var ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION

@zovonoir
Contributor Author

zovonoir commented May 7, 2026

Hi @ganyi1996ppo, thanks for the suggestion!

I tested ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 on Qwen3.5-27B (MI308X, TP2, ISL=60381, OSL=132, bf16) and compared it against the Triton fused MRoPE kernel in this PR. Here are the results:

| Configuration | TPOT (ms) | TTFT (ms) | Tput/GPU (tok/s) |
|---------------|-----------|-----------|------------------|
| ATOM baseline (no optimizations) | 18.7 | 15266 | 1574 |
| This PR (Triton RMSNormGated + fused MRoPE) | 15.6 | 15300 | 1605 |
| ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 (replaces this PR's MRoPE path) | 16.6 | 15271 | 1597 |

The built-in fusion env var gives ~12% TPOT improvement over baseline (18.7→16.6ms), but this PR's Triton approach achieves ~17% (18.7→15.6ms). The two paths are mutually exclusive in Qwen3NextAttention.forward() — enabling the env var bypasses try_mrope_qk_fused().

Note: the env var path fuses QK norm + RoPE + cache quant inside Attention, while this PR fuses QK norm + RoPE externally via a dedicated Triton kernel. The difference may come from the Triton kernel being better tuned for this specific workload.

— Claude (AI assistant helping @zovonoir with performance analysis)

zovonoir and others added 2 commits May 7, 2026 14:24
Keep aiter rmsnorm2d_fwd path over upstream's triton_gemma_rmsnorm —
direct aiter kernel calls avoid Triton JIT overhead and match the
approach used for RMSNormGated.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 7, 2026 08:34
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

z_ptr,
weight_ptr,
out_ptr,
num_rows: tl.constexpr,
pos_stride_row: tl.constexpr,
cos_stride_pos: tl.constexpr,
sin_stride_pos: tl.constexpr,
num_tokens: tl.constexpr,
Comment on lines +216 to +234
mrope_section = getattr(rotary_emb, "mrope_section", None)
if (
positions.ndim != 2
or mrope_section is None
or not getattr(rotary_emb, "mrope_interleaved", False)
or head_size != 256
or getattr(rotary_emb, "rotary_dim", None) != 64
or not getattr(rotary_emb, "is_neox_style", False)
or q.ndim != 2
or k.ndim != 2
or q.shape[1] != num_q_heads * head_size
or k.shape[1] != num_k_heads * head_size
):
return None

q_out = torch.empty_like(q)
k_out = torch.empty_like(k)
num_tokens = positions.shape[1]
block_d = triton.next_power_of_2(head_size)
Comment on lines +677 to +681
def _get_aiter_weight(self) -> torch.Tensor:
"""Cache weight + 1.0 for aiter rmsnorm (which uses x*w, not x*(1+w))."""
if not hasattr(self, "_aiter_weight_cache") or self._aiter_weight_cache is None:
self._aiter_weight_cache = self.weight.data + 1.0
return self._aiter_weight_cache