Remove chunk split in Qwen3.5 and Qwen3Next (#682)
Conversation
Signed-off-by: ganyi <ygan@amd.com>
There was a problem hiding this comment.
Pull request overview
This PR removes the runtime “split-chunk” Triton path for Qwen3.5 and Qwen3-Next GDN attention projections by switching to contiguous projection outputs (via weight deinterleaving at load time) and using torch.split in the forward path. It also adds explicit zeroing of padding regions to make CUDA graph replay safer.
Changes:
- Replace fused split-chunk runtime logic in Qwen3.5 / Qwen3-Next forwards with torch.split-based slicing.
- Add deinterleaving weight loaders so projection outputs become contiguous in the expected [q|k|v|z|b|a] (or [b|a]) layouts.
- Zero-fill attention output buffers on profile runs and zero the padding tail for CUDA graph replay safety.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
atom/plugin/vllm/attention_backend/attention_gdn.py |
Zeroes outputs for profile-run early return and clears padding tail for CUDA graph replay safety. |
atom/models/qwen3_next.py |
Removes runtime split-chunk usage; adds deinterleaving loaders and switches to torch.split in forward. |
atom/models/qwen3_5.py |
Removes runtime split-chunk usage; switches to torch.split/reshape and allocates core_attn_out directly. |
atom/model_ops/linear.py |
Reworks QKVZBAParallelLinear section sizing and adds deinterleaving logic in its weight_loader. |
atom/model_ops/attention_gdn.py |
Mirrors the CUDA-graph-safety zeroing behavior for the non-plugin attention op path. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Even indices → b, odd indices → a | ||
| param.data[:nv] = src[0::2] | ||
| param.data[nv:] = src[1::2] |
| delattr(linear.weight, "weight_loader") | ||
| setattr( | ||
| linear.weight, | ||
| "weight_loader", | ||
| _qkvz_deinterleave_weight_loader( | ||
| self.num_k_heads, | ||
| self.num_v_heads, | ||
| self.head_k_dim, | ||
| self.head_v_dim, | ||
| self.tp_size, | ||
| self.tp_rank, | ||
| ), | ||
| ) |
| delattr(linear.weight, "weight_loader") | ||
| setattr( | ||
| linear.weight, | ||
| "weight_loader", | ||
| _ba_deinterleave_weight_loader( | ||
| self.num_v_heads, | ||
| self.tp_size, | ||
| self.tp_rank, | ||
| ), | ||
| ) |
| for g in range(nk): | ||
| base = g * group_size | ||
| # q | ||
| param.data[g * head_k_dim : (g + 1) * head_k_dim] = src[ | ||
| base : base + head_k_dim | ||
| ] | ||
| # k | ||
| param.data[q_total + g * head_k_dim : q_total + (g + 1) * head_k_dim] = src[ | ||
| base + head_k_dim : base + 2 * head_k_dim | ||
| ] | ||
| # v sub-heads | ||
| for s in range(R): | ||
| v_src = base + 2 * head_k_dim + s * head_v_dim | ||
| v_dst = q_total + k_total + (g * R + s) * head_v_dim | ||
| param.data[v_dst : v_dst + head_v_dim] = src[v_src : v_src + head_v_dim] | ||
| # z sub-heads | ||
| for s in range(R): | ||
| z_src = base + 2 * head_k_dim + R * head_v_dim + s * head_v_dim | ||
| z_dst = q_total + k_total + v_total + (g * R + s) * head_v_dim | ||
| param.data[z_dst : z_dst + head_v_dim] = src[z_src : z_src + head_v_dim] | ||
|
|
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| QKVZParallelLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| ) # noqa: F401 |
| @staticmethod | ||
| def _match_dtype(param_data, loaded_weight): | ||
| """View param_data as loaded_weight's dtype if they differ but share element size. | ||
|
|
||
| This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where | ||
| the param is ``float8_e4m3fnuz`` but the checkpoint stores | ||
| ``float8_e4m3fn``. The normalisation happens later in | ||
| ``process_weights_after_loading``. | ||
| """ | ||
| if ( | ||
| param_data.dtype != loaded_weight.dtype | ||
| and param_data.element_size() == loaded_weight.element_size() | ||
| ): | ||
| return param_data.view(loaded_weight.dtype) | ||
| return param_data |
| def forward_cuda( | ||
| self, | ||
| x: torch.Tensor, | ||
| residual: torch.Tensor | None = None, | ||
| ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: | ||
| from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton | ||
| # Use the aiter HIP fused_qk_rmsnorm_group_quant kernel in no-quant mode | ||
| # (q_out_scale=None) to perform Gemma RMSNorm + optional residual add. | ||
| # Same math as the Triton kernel: out = rmsnorm(x [+ residual]) * (1 + w), | ||
| # but executed by the aiter kernel for higher achieved bandwidth. | ||
| from aiter.ops.fused_qk_rmsnorm_group_quant import fused_qk_rmsnorm_group_quant | ||
|
|
||
| ori_shape = x.shape | ||
| x_2d = x.view(-1, ori_shape[-1]) | ||
|
|
||
| out = torch.empty_like(x_2d) | ||
| if residual is not None: | ||
| residual_2d = residual.view(-1, ori_shape[-1]) | ||
| res_out = torch.empty_like(x_2d) | ||
| else: | ||
| residual_2d = None | ||
| res_out = None | ||
|
|
||
| return gemma_rmsnorm_triton( | ||
| x, self.weight.data, self.variance_epsilon, residual | ||
| fused_qk_rmsnorm_group_quant( | ||
| q_out_quantized=None, | ||
| q_out_scale=None, | ||
| q=x_2d, | ||
| q_weight=self.weight.data, | ||
| q_epsilon=self.variance_epsilon, | ||
| q_out_unquantized=out, | ||
| q_res_out=res_out, | ||
| q_residual=residual_2d, | ||
| gemma_norm=True, | ||
| ) |
84aeaa1 to
7ecfc87
Compare
| ws = getattr(self, "weight_scale", None) | ||
| if ws is not None: | ||
| hk_s, hv_s = hk // 128, hv // 128 | ||
| s = ws.data | ||
| ds = torch.empty_like(s) | ||
| self._deinterleave(ds, s, nk, R, hk_s, hv_s) | ||
| self.weight_scale.data = ds |
| @staticmethod | ||
| def _match_dtype(param_data, loaded_weight): | ||
| """View param_data as loaded_weight's dtype if they differ but share element size. | ||
|
|
||
| This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where | ||
| the param is ``float8_e4m3fnuz`` but the checkpoint stores | ||
| ``float8_e4m3fn``. The normalisation happens later in | ||
| ``process_weights_after_loading``. | ||
| """ | ||
| if ( | ||
| param_data.dtype != loaded_weight.dtype | ||
| and param_data.element_size() == loaded_weight.element_size() | ||
| ): | ||
| return param_data.view(loaded_weight.dtype) | ||
| return param_data |
Signed-off-by: ganyi <ygan@amd.com>
a802d3b to
5464767
Compare
| v_feat_idx = idx_feats - v_start_dim | ||
| is_q_block = idx_feats < k_start_dim | ||
| is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim) | ||
| is_v_block = idx_feats >= v_start_dim |
| if ws is not None: | ||
| hk_s, hv_s = hk // 128, hv // 128 | ||
| s = ws.data | ||
| ds = torch.empty_like(s) | ||
| self._deinterleave(ds, s, nk, R, hk_s, hv_s) | ||
| self.weight_scale.data = ds |
| Output layout is always ``[q | k | v | z | b | a]`` contiguous on dim-0, | ||
| so the caller can use ``torch.split`` for zero-copy views. |
5464767 to
46a2b76
Compare
46a2b76 to
bc3d69f
Compare
| v_feat_idx = idx_feats - v_start_dim | ||
| is_q_block = idx_feats < k_start_dim | ||
| is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim) | ||
| is_v_block = idx_feats >= v_start_dim |
| @staticmethod | ||
| def _match_dtype(param_data, loaded_weight): | ||
| """View param_data as loaded_weight's dtype if they differ but share element size. | ||
|
|
||
| This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where | ||
| the param is ``float8_e4m3fnuz`` but the checkpoint stores | ||
| ``float8_e4m3fn``. The normalisation happens later in | ||
| ``process_weights_after_loading``. | ||
| """ | ||
| if ( | ||
| param_data.dtype != loaded_weight.dtype | ||
| and param_data.element_size() == loaded_weight.element_size() | ||
| ): | ||
| return param_data.view(loaded_weight.dtype) | ||
| return param_data |
Signed-off-by: ganyi <ygan@amd.com>
bc3d69f to
4751bc3
Compare
Signed-off-by: ganyi <ygan@amd.com>
| # within one of q/k/v — only one mask is all-true per block. | ||
| q_feat_idx = idx_feats | ||
| k_feat_idx = idx_feats - k_start_dim | ||
| v_feat_idx = idx_feats - v_start_dim | ||
| is_q_block = idx_feats < k_start_dim | ||
| is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim) | ||
| is_v_block = idx_feats >= v_start_dim |
| @staticmethod | ||
| def _match_dtype(param_data, loaded_weight): | ||
| """View param_data as loaded_weight's dtype if they differ but share element size. | ||
|
|
||
| This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where | ||
| the param is ``float8_e4m3fnuz`` but the checkpoint stores | ||
| ``float8_e4m3fn``. The normalisation happens later in | ||
| ``process_weights_after_loading``. | ||
| """ | ||
| if ( | ||
| param_data.dtype != loaded_weight.dtype | ||
| and param_data.element_size() == loaded_weight.element_size() | ||
| ): | ||
| return param_data.view(loaded_weight.dtype) | ||
| return param_data |
|
|
||
| return gemma_rmsnorm_triton( | ||
| x, self.weight.data, self.variance_epsilon, residual | ||
| fused_qk_rmsnorm_group_quant( |
There was a problem hiding this comment.
switch to fused_qk_rmsnorm?
There was a problem hiding this comment.
fused_qk_rmsnorm does not seem to support gemma yet...
| "extraArgs": "-tp 8", | ||
| "env_vars": "", | ||
| "extraArgs": "-tp 8 --kv_cache_dtype fp8", | ||
| "env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1", |
There was a problem hiding this comment.
enable this one by default?
Signed-off-by: ganyi <ygan@amd.com>
| # Output layout is [q|k|v|z|b|a] contiguous (deinterleaved at load) | ||
| mixed_qkv, z_flat, b, a = torch.split( | ||
| projected, [qkv_size, z_size, b_size, a_size], dim=-1 | ||
| ) |
Signed-off-by: ganyi <ygan@amd.com>
| # QK-norm-rope-cache-quant fusion for Qwen3-MoE; disabled by default. | ||
| # Enable for Qwen3-MoE to get better performance. |
| # use_ps = self.adopt_persistent_kernel( | ||
| # head_size, num_kv_heads, num_q_heads_total | ||
| # ) | ||
| use_ps = True |
| v_feat_idx = idx_feats - v_start_dim | ||
| is_q_block = idx_feats < k_start_dim | ||
| is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim) | ||
| is_v_block = idx_feats >= v_start_dim |
Motivation
depends on ROCm/aiter#3010
Technical Details
Test Plan
Test Result
Submission Checklist