perf: fused Triton kernels for Qwen3.5 RMSNorm and MRoPE #708
Conversation
Three performance optimizations for Qwen3.5 (GDN + full attention), plus the model-side wiring:

1. `RMSNormGated`: add a Triton fused kernel (decode + prefill paths) and aiter HIP kernel integration with a fallback chain
2. `GemmaRMSNorm`: replace `torch.compile` with direct aiter `rmsnorm2d_fwd` calls; cache `(weight + 1.0)` to avoid a per-call allocation
3. MRoPE: add a fused Triton kernel that applies multi-resolution RoPE to Q and K in a single launch (specialized for `head_size=256`, `rotary_dim=64`, Neox-style)
4. `qwen3_next`: integrate the MRoPE fusion; remove the pre-allocated output pattern (`output[:] = ...`) in favor of direct returns

Benchmark (MI308X, Qwen3.5-27B, TP2, ISL=60381, OSL=132):

| Version | TPOT (ms) | TTFT (ms) | Tput/GPU |
|---|---|---|---|
| ATOM baseline | 18.7 | 15266 | 1574 |
| + RMSNormGated Triton | 16.7 | 15247 | 1597 |
| + this patch | 15.6 | 15300 | 1605 |

TPOT reduced by 17% end-to-end (18.7 → 15.6 ms).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
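The fallback chain in item 1 is only named here, not shown; a minimal sketch of the shape such a chain usually takes (fused kernels tried first, plain PyTorch as the last resort). The hook registry and the SiLU-gated reference math below are illustrative assumptions, not this PR's actual entry points:

```python
import torch
import torch.nn.functional as F

# Fused implementations (e.g. aiter HIP, Triton) would be registered here at
# import time; an empty list simply falls through to the reference path.
_FUSED_IMPLS = []

def rmsnorm_gated(x, gate, weight, eps=1e-6):
    for impl in _FUSED_IMPLS:
        out = impl(x, gate, weight, eps)
        if out is not None:  # an impl returns None for inputs it can't handle
            return out
    # Reference path: RMS-normalize, scale by weight, gate with SiLU
    # (the gating commonly used in GDN-style gated norms; an assumption here).
    rms = x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps).to(x.dtype)
    return rms * weight * F.silu(gate)
```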
Pull request overview
This PR introduces fused Triton kernels to reduce launch overhead and improve end-to-end latency for Qwen3.5, including a fused Q/K MRoPE path and optimized RMSNorm variants.
Changes:
- Integrates a fused Triton MRoPE kernel for applying Qwen3.5 multi-resolution RoPE to Q and K in a single launch, with a fallback to the existing rotary embedding implementation.
- Adds Triton (and optional aiter HIP) fused kernels for `RMSNormGated`, and replaces `torch.compile` usage in `GemmaRMSNorm` with direct aiter `rmsnorm2d_fwd` calls (including a cached `weight + 1.0`).
- Refactors `qwen3_next.py` attention/linear-attention codepaths to return tensors directly (removing the preallocated output-buffer write pattern).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `atom/models/qwen3_next.py` | Wires in the fused MRoPE Q/K path and refactors attention APIs to return outputs instead of writing into preallocated buffers. |
| `atom/model_ops/triton_mrope.py` | Adds a new fused Triton kernel specialized for Qwen3.5 MRoPE (Q/K fused). |
| `atom/model_ops/layernorm.py` | Adds/uses fused Triton + aiter `RMSNormGated` paths and switches `GemmaRMSNorm` to aiter rmsnorm2d kernels with a cached weight adjustment. |
```python
hidden_states = self.linear_attn(
    hidden_states=(
        hidden_bf16 if hidden_bf16 is not None else hidden_states
    ),
    output=self_attention_output,
    x_fp8=hidden_states if x_scale is not None else None,
```

```python
def forward(
    self,
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    x_fp8=None,
    x_scale=None,
):
```
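For context on the refactor these hunks belong to: the old convention had callers preallocate `output` and the layer write into it (`output[:] = ...`), while the new code returns the tensor directly, dropping the staging buffer and the in-place copy. A toy before/after pair, not the PR's code:

```python
import torch

class OldStyle(torch.nn.Module):
    def forward(self, x: torch.Tensor, output: torch.Tensor) -> None:
        output[:] = x * 2.0      # extra device copy into the caller's buffer

class NewStyle(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0           # result returned directly, no staging buffer

x = torch.randn(4, 8)
buf = torch.empty_like(x)
OldStyle()(x, buf)               # caller must preallocate and keep `buf` alive
out = NewStyle()(x)
assert torch.equal(buf, out)
```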
```python
mrope_section = getattr(rotary_emb, "mrope_section", None)
if (
    positions.ndim != 2
    or mrope_section is None
    or not getattr(rotary_emb, "mrope_interleaved", False)
    or q.shape[1] != num_q_heads * head_size
    or k.shape[1] != num_k_heads * head_size
):
    return None

q_out = torch.empty_like(q)
k_out = torch.empty_like(k)
```
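The `return None` above is the fused path declining inputs outside its specialization; the caller is then expected to fall back to the generic rotary embedding. A runnable sketch of that contract with stub names (`try_fused_mrope_qk` and `apply_rope` are hypothetical, not the PR's API):

```python
import torch

def try_fused_mrope_qk(positions, q, k, rotary_emb, nq, nk, head_size):
    """Stub for the fused entry point: (q_out, k_out) on success, else None."""
    return None  # pretend the guard above rejected these inputs

def apply_rope(positions, q, k):
    """Stub for the generic (unfused) rotary-embedding fallback."""
    return q, k

positions = torch.zeros(3, 8, dtype=torch.long)  # mrope positions: [3, num_tokens]
q, k = torch.randn(8, 4 * 256), torch.randn(8, 2 * 256)
result = try_fused_mrope_qk(positions, q, k, None, 4, 2, 256)
q, k = result if result is not None else apply_rope(positions, q, k)
```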
```python
def _get_aiter_weight(self) -> torch.Tensor:
    """Cache weight + 1.0 for aiter rmsnorm (which uses x*w, not x*(1+w))."""
    if not hasattr(self, "_aiter_weight_cache") or self._aiter_weight_cache is None:
        self._aiter_weight_cache = self.weight.data + 1.0
    return self._aiter_weight_cache
```
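Why caching `weight + 1.0` is correct: Gemma-style RMSNorm scales the normalized input by `(1 + weight)`, while aiter's `rmsnorm2d_fwd` scales by exactly the weight it receives, so handing it the shifted weight reproduces the Gemma convention without a per-call allocation. A quick plain-PyTorch check of the identity (not the PR's code):

```python
import torch

def rmsnorm_x_times_w(x, w, eps):  # kernel convention: x_norm * w
    rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return rms * w

x, w, eps = torch.randn(4, 64), torch.randn(64) * 0.1, 1e-6
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
gemma_out = rms * (1.0 + w)                      # Gemma convention: x_norm * (1 + w)
aiter_like = rmsnorm_x_times_w(x, w + 1.0, eps)  # pass the cached, shifted weight
assert torch.allclose(gemma_out, aiter_like)
```

One thing to keep in mind with this pattern: the cache is computed once from `self.weight.data`, so it assumes the weight is frozen after the first call (or that `_aiter_weight_cache` is reset to `None` whenever the weight changes).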
```python
x = x.view(-1, ori_shape[-1])

if residual is not None:
    residual = residual.view(-1, ori_shape[-1])
    out = torch.empty_like(x)
    residual_out = torch.empty_like(residual)
    rmsnorm2d_fwd_with_add(
        out, x, residual, residual_out, weight, self.variance_epsilon
    )
    return out.view(ori_shape), residual_out.view(ori_shape)

return rmsnorm2d_fwd(x, weight, self.variance_epsilon).view(ori_shape)
```
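For readers unfamiliar with the aiter contract used above, the fused call is assumed to add the residual first and RMS-normalize the sum, returning both the normalized output and the updated residual, which matches how the hunk consumes `(out, residual_out)`. A plain-PyTorch sketch of those assumed semantics, not aiter's implementation:

```python
import torch

def rmsnorm2d_fwd_with_add_ref(x, residual, weight, eps):
    # Assumed fused-add-rmsnorm semantics: residual_out = x + residual,
    # out = rmsnorm(residual_out) * weight. Verify against aiter before relying on it.
    residual_out = x + residual
    var = residual_out.float().pow(2).mean(-1, keepdim=True)
    out = (residual_out.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    return out, residual_out
```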
We seem to already have this, if I remember correctly...
Hi @ganyi1996ppo, thanks for the suggestion! I tested the built-in fusion env var you're referring to.

It gives ~12% TPOT improvement over baseline (18.7→16.6 ms), but this PR's Triton approach achieves ~17% (18.7→15.6 ms). The two paths are mutually exclusive in the current code. Note: the env var path fuses QK norm + RoPE + cache quant internally.

— Claude (AI assistant helping @zovonoir with performance analysis)
Keep aiter rmsnorm2d_fwd path over upstream's triton_gemma_rmsnorm — direct aiter kernel calls avoid Triton JIT overhead and match the approach used for RMSNormGated. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
```python
z_ptr,
weight_ptr,
out_ptr,
num_rows: tl.constexpr,
```

```python
pos_stride_row: tl.constexpr,
cos_stride_pos: tl.constexpr,
sin_stride_pos: tl.constexpr,
num_tokens: tl.constexpr,
```
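For reference, the math the kernel specializes is Neox-style partial RoPE: only the first `rotary_dim` (64) of each `head_size` (256) slice is rotated, split across halves rather than interleaved pairs. An unfused PyTorch equivalent for a single head slice (a reference sketch, not the Triton kernel):

```python
import torch

def neox_partial_rope(x, cos, sin, rotary_dim=64):
    # x: [num_tokens, head_size]; cos/sin: [num_tokens, rotary_dim // 2]
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = rot.chunk(2, dim=-1)  # Neox style: rotate across the two halves
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin, rest], dim=-1)

t = torch.randn(8, 256)
ang = torch.rand(8, 32)
out = neox_partial_rope(t, ang.cos(), ang.sin())
assert out.shape == t.shape
```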
```python
mrope_section = getattr(rotary_emb, "mrope_section", None)
if (
    positions.ndim != 2
    or mrope_section is None
    or not getattr(rotary_emb, "mrope_interleaved", False)
    or head_size != 256
    or getattr(rotary_emb, "rotary_dim", None) != 64
    or not getattr(rotary_emb, "is_neox_style", False)
    or q.ndim != 2
    or k.ndim != 2
    or q.shape[1] != num_q_heads * head_size
    or k.shape[1] != num_k_heads * head_size
):
    return None

q_out = torch.empty_like(q)
k_out = torch.empty_like(k)
num_tokens = positions.shape[1]
block_d = triton.next_power_of_2(head_size)
```
Summary
- Triton fused kernel for `RMSNormGated` (decode + prefill paths) with aiter HIP kernel integration and fallback chain
- Replace `torch.compile` in `GemmaRMSNorm` with direct aiter `rmsnorm2d_fwd` calls, caching `weight + 1.0` to avoid per-call allocation
- Fused MRoPE integrated in `qwen3_next.py`; remove the pre-allocated output pattern (`output[:] = ...`) in favor of direct return

Benchmark
MI308X, Qwen3.5-27B, ATOM SGLang plugin mode, TP2, ISL=60381, OSL=132:

| Version | TPOT (ms) | TTFT (ms) | Tput/GPU |
|---|---|---|---|
| ATOM baseline | 18.7 | 15266 | 1574 |
| + RMSNormGated Triton | 16.7 | 15247 | 1597 |
| + this patch | 15.6 | 15300 | 1605 |
TPOT reduced by 17% end-to-end (18.7 → 15.6 ms).
Files Changed
- `atom/model_ops/layernorm.py` — Triton RMSNormGated kernels + aiter HIP integration + GemmaRMSNorm optimization
- `atom/model_ops/triton_mrope.py` — new file: fused MRoPE Triton kernel for Qwen3.5
- `atom/models/qwen3_next.py` — integrate MRoPE fusion + remove pre-allocated output pattern

Test plan
🤖 Generated with Claude Code