Remove chunk split in Qwen3.5 and Qwen3Next #682

Open
ganyi1996ppo wants to merge 9 commits into main from ganyi/remove_chunk_split

Conversation

Contributor

@ganyi1996ppo ganyi1996ppo commented May 2, 2026

Motivation

depends on ROCm/aiter#3010

Technical Details

Test Plan

Test Result

Submission Checklist

Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings May 2, 2026 13:11
Contributor

Copilot AI left a comment


Pull request overview

This PR removes the runtime “split-chunk” Triton path for Qwen3.5 and Qwen3-Next GDN attention projections by switching to contiguous projection outputs (via weight deinterleaving at load time) and using torch.split in the forward path. It also adds explicit zeroing of padding regions to make CUDA graph replay safer.

Changes:

  • Replace fused split-chunk runtime logic in Qwen3.5 / Qwen3-Next forwards with torch.split-based slicing.
  • Add deinterleaving weight loaders so projection outputs become contiguous in the expected [q|k|v|z|b|a] (or [b|a]) layouts (see the layout sketch after this list).
  • Zero-fill attention output buffers on profile runs and zero the padding tail for CUDA graph replay safety.
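
A minimal sketch of the layout these changes target (toy sizes; the real head counts, dims, and variable names in the PR differ): the fused projection output is made contiguous as [q | k | v | z | b | a] by deinterleaving the weights at load time, so the forward pass can slice it with plain torch.split views instead of a runtime split-chunk kernel.

```python
import torch

# Hypothetical toy sizes, not the model's real configuration.
num_k_heads, num_v_heads = 2, 4
head_k_dim, head_v_dim = 8, 8

q_size = num_k_heads * head_k_dim
k_size = num_k_heads * head_k_dim
v_size = num_v_heads * head_v_dim
z_size = num_v_heads * head_v_dim
b_size = num_v_heads
a_size = num_v_heads

# Projection output, already deinterleaved to [q | k | v | z | b | a] on dim -1.
projected = torch.randn(3, q_size + k_size + v_size + z_size + b_size + a_size)

# torch.split on a contiguous layout returns zero-copy views; no extra kernel launch.
q, k, v, z, b, a = torch.split(
    projected, [q_size, k_size, v_size, z_size, b_size, a_size], dim=-1
)
assert q.shape[-1] == q_size and a.shape[-1] == a_size
```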

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Summary per file:

  • atom/plugin/vllm/attention_backend/attention_gdn.py: Zeroes outputs for the profile-run early return and clears the padding tail for CUDA graph replay safety (see the padding sketch after this list).
  • atom/models/qwen3_next.py: Removes runtime split-chunk usage; adds deinterleaving loaders and switches to torch.split in the forward pass.
  • atom/models/qwen3_5.py: Removes runtime split-chunk usage; switches to torch.split/reshape and allocates core_attn_out directly.
  • atom/model_ops/linear.py: Reworks QKVZBAParallelLinear section sizing and adds deinterleaving logic in its weight_loader.
  • atom/model_ops/attention_gdn.py: Mirrors the CUDA-graph-safety zeroing behavior for the non-plugin attention op path.
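
For the two attention_gdn.py entries, a hedged sketch of the padding-zeroing idea (hypothetical helper name and shapes, not the PR's actual code): when a captured CUDA graph is replayed with fewer real tokens than the padded capture size, the tail of the output buffer can hold stale values, so it is cleared explicitly.

```python
import torch

def clear_padding_tail(out: torch.Tensor, num_actual_tokens: int) -> None:
    # out: [max_num_tokens, hidden]; rows >= num_actual_tokens are padding that
    # a CUDA graph replay would otherwise leave holding stale data.
    if num_actual_tokens < out.shape[0]:
        out[num_actual_tokens:].zero_()

out = torch.randn(8, 16)          # graph-sized buffer with 8 padded slots
clear_padding_tail(out, num_actual_tokens=5)
assert torch.all(out[5:] == 0)
```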


Comment thread atom/models/qwen3_next.py Outdated
Comment on lines +179 to +181
# Even indices → b, odd indices → a
param.data[:nv] = src[0::2]
param.data[nv:] = src[1::2]
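
A toy illustration of the even/odd deinterleave in the hunk above (hypothetical row count; src stands in for the interleaved checkpoint slice): the checkpoint stores b and a interleaved per head, and the loader rewrites them into a contiguous [b | a] block.

```python
import torch

nv = 4                                # assumed toy head count
src = torch.arange(2 * nv)            # interleaved rows: b0, a0, b1, a1, ...
dst = torch.empty_like(src)
dst[:nv] = src[0::2]                  # even indices -> b block
dst[nv:] = src[1::2]                  # odd indices  -> a block
# dst is now [b0, b1, b2, b3, a0, a1, a2, a3]
```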
Comment thread atom/models/qwen3_next.py Outdated
Comment on lines +693 to +705
delattr(linear.weight, "weight_loader")
setattr(
    linear.weight,
    "weight_loader",
    _qkvz_deinterleave_weight_loader(
        self.num_k_heads,
        self.num_v_heads,
        self.head_k_dim,
        self.head_v_dim,
        self.tp_size,
        self.tp_rank,
    ),
)
Comment thread atom/models/qwen3_next.py Outdated
Comment on lines +723 to +732
delattr(linear.weight, "weight_loader")
setattr(
    linear.weight,
    "weight_loader",
    _ba_deinterleave_weight_loader(
        self.num_v_heads,
        self.tp_size,
        self.tp_rank,
    ),
)
Comment thread atom/models/qwen3_next.py Outdated
Comment on lines +140 to +160
for g in range(nk):
    base = g * group_size
    # q
    param.data[g * head_k_dim : (g + 1) * head_k_dim] = src[
        base : base + head_k_dim
    ]
    # k
    param.data[q_total + g * head_k_dim : q_total + (g + 1) * head_k_dim] = src[
        base + head_k_dim : base + 2 * head_k_dim
    ]
    # v sub-heads
    for s in range(R):
        v_src = base + 2 * head_k_dim + s * head_v_dim
        v_dst = q_total + k_total + (g * R + s) * head_v_dim
        param.data[v_dst : v_dst + head_v_dim] = src[v_src : v_src + head_v_dim]
    # z sub-heads
    for s in range(R):
        z_src = base + 2 * head_k_dim + R * head_v_dim + s * head_v_dim
        z_dst = q_total + k_total + v_total + (g * R + s) * head_v_dim
        param.data[z_dst : z_dst + head_v_dim] = src[z_src : z_src + head_v_dim]

Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings May 4, 2026 13:03
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.



Comment thread atom/models/qwen3_next.py
     QKVZParallelLinear,
     RowParallelLinear,
-)
+)  # noqa: F401
Comment thread atom/model_ops/linear.py
Comment on lines +887 to +901
@staticmethod
def _match_dtype(param_data, loaded_weight):
    """View param_data as loaded_weight's dtype if they differ but share element size.

    This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where
    the param is ``float8_e4m3fnuz`` but the checkpoint stores
    ``float8_e4m3fn``. The normalisation happens later in
    ``process_weights_after_loading``.
    """
    if (
        param_data.dtype != loaded_weight.dtype
        and param_data.element_size() == loaded_weight.element_size()
    ):
        return param_data.view(loaded_weight.dtype)
    return param_data
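
A hedged illustration of the dtype-view trick the docstring above describes, using toy 2-byte dtypes rather than FP8: when two dtypes share an element size, Tensor.view(dtype) reinterprets the same storage without copying, so the checkpoint bytes can land directly in the parameter.

```python
import torch

param = torch.zeros(4, dtype=torch.int16)     # stands in for float8_e4m3fnuz
loaded = torch.ones(4, dtype=torch.bfloat16)  # stands in for float8_e4m3fn
assert param.element_size() == loaded.element_size()

param_view = param.view(torch.bfloat16)       # same storage, new dtype
param_view.copy_(loaded)                      # checkpoint bytes land in param
```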
Comment on lines 545 to 577
    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton
        # Use the aiter HIP fused_qk_rmsnorm_group_quant kernel in no-quant mode
        # (q_out_scale=None) to perform Gemma RMSNorm + optional residual add.
        # Same math as the Triton kernel: out = rmsnorm(x [+ residual]) * (1 + w),
        # but executed by the aiter kernel for higher achieved bandwidth.
        from aiter.ops.fused_qk_rmsnorm_group_quant import fused_qk_rmsnorm_group_quant

        ori_shape = x.shape
        x_2d = x.view(-1, ori_shape[-1])

        out = torch.empty_like(x_2d)
        if residual is not None:
            residual_2d = residual.view(-1, ori_shape[-1])
            res_out = torch.empty_like(x_2d)
        else:
            residual_2d = None
            res_out = None

-       return gemma_rmsnorm_triton(
-           x, self.weight.data, self.variance_epsilon, residual
-       )
+       fused_qk_rmsnorm_group_quant(
+           q_out_quantized=None,
+           q_out_scale=None,
+           q=x_2d,
+           q_weight=self.weight.data,
+           q_epsilon=self.variance_epsilon,
+           q_out_unquantized=out,
+           q_res_out=res_out,
+           q_residual=residual_2d,
+           gemma_norm=True,
+       )
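
For reference, a plain-PyTorch sketch of the math described in the comment in the hunk above (not the aiter kernel, and not guaranteed to match it bit for bit): Gemma-style RMSNorm with an optional residual add, scaling by (1 + weight).

```python
import torch

def gemma_rmsnorm_reference(x, weight, eps, residual=None):
    # Optional residual add happens before normalization; the updated
    # residual is returned alongside the normalized output.
    if residual is not None:
        x = x + residual
    variance = x.float().pow(2).mean(-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * (1 + weight)
    return out if residual is None else (out, x)
```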
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings May 6, 2026 07:11
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/remove_chunk_split branch from 84aeaa1 to 7ecfc87 Compare May 6, 2026 07:11
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Comment thread atom/model_ops/linear.py
Comment on lines +918 to +924
ws = getattr(self, "weight_scale", None)
if ws is not None:
    hk_s, hv_s = hk // 128, hv // 128
    s = ws.data
    ds = torch.empty_like(s)
    self._deinterleave(ds, s, nk, R, hk_s, hv_s)
    self.weight_scale.data = ds
Comment thread atom/model_ops/linear.py
Comment on lines +887 to +901
@staticmethod
def _match_dtype(param_data, loaded_weight):
    """View param_data as loaded_weight's dtype if they differ but share element size.

    This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where
    the param is ``float8_e4m3fnuz`` but the checkpoint stores
    ``float8_e4m3fn``. The normalisation happens later in
    ``process_weights_after_loading``.
    """
    if (
        param_data.dtype != loaded_weight.dtype
        and param_data.element_size() == loaded_weight.element_size()
    ):
        return param_data.view(loaded_weight.dtype)
    return param_data
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings May 6, 2026 12:28
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/remove_chunk_split branch 2 times, most recently from a802d3b to 5464767 Compare May 6, 2026 12:28
@ganyi1996ppo ganyi1996ppo mentioned this pull request May 6, 2026
1 task
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.

v_feat_idx = idx_feats - v_start_dim
is_q_block = idx_feats < k_start_dim
is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim)
is_v_block = idx_feats >= v_start_dim
Comment thread atom/model_ops/linear.py
Comment on lines +919 to +924
if ws is not None:
    hk_s, hv_s = hk // 128, hv // 128
    s = ws.data
    ds = torch.empty_like(s)
    self._deinterleave(ds, s, nk, R, hk_s, hv_s)
    self.weight_scale.data = ds
Comment thread atom/model_ops/linear.py
Comment on lines +640 to +641
Output layout is always ``[q | k | v | z | b | a]`` contiguous on dim-0,
so the caller can use ``torch.split`` for zero-copy views.
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/remove_chunk_split branch from 5464767 to 46a2b76 Compare May 7, 2026 00:55
Copilot AI review requested due to automatic review settings May 7, 2026 02:05
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/remove_chunk_split branch from 46a2b76 to bc3d69f Compare May 7, 2026 02:05
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

v_feat_idx = idx_feats - v_start_dim
is_q_block = idx_feats < k_start_dim
is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim)
is_v_block = idx_feats >= v_start_dim
Comment thread atom/model_ops/linear.py
Comment on lines +887 to +901
@staticmethod
def _match_dtype(param_data, loaded_weight):
    """View param_data as loaded_weight's dtype if they differ but share element size.

    This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where
    the param is ``float8_e4m3fnuz`` but the checkpoint stores
    ``float8_e4m3fn``. The normalisation happens later in
    ``process_weights_after_loading``.
    """
    if (
        param_data.dtype != loaded_weight.dtype
        and param_data.element_size() == loaded_weight.element_size()
    ):
        return param_data.view(loaded_weight.dtype)
    return param_data
Signed-off-by: ganyi <ygan@amd.com>
@ganyi1996ppo ganyi1996ppo force-pushed the ganyi/remove_chunk_split branch from bc3d69f to 4751bc3 Compare May 7, 2026 05:04
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings May 7, 2026 05:34
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Comment on lines +97 to +103
# within one of q/k/v — only one mask is all-true per block.
q_feat_idx = idx_feats
k_feat_idx = idx_feats - k_start_dim
v_feat_idx = idx_feats - v_start_dim
is_q_block = idx_feats < k_start_dim
is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim)
is_v_block = idx_feats >= v_start_dim
Comment thread atom/model_ops/linear.py
Comment on lines +887 to +901
@staticmethod
def _match_dtype(param_data, loaded_weight):
    """View param_data as loaded_weight's dtype if they differ but share element size.

    This mirrors ``weight_loader_process`` behaviour for FP8 on ROCm where
    the param is ``float8_e4m3fnuz`` but the checkpoint stores
    ``float8_e4m3fn``. The normalisation happens later in
    ``process_weights_after_loading``.
    """
    if (
        param_data.dtype != loaded_weight.dtype
        and param_data.element_size() == loaded_weight.element_size()
    ):
        return param_data.view(loaded_weight.dtype)
    return param_data

-       return gemma_rmsnorm_triton(
-           x, self.weight.data, self.variance_epsilon, residual
+       fused_qk_rmsnorm_group_quant(
Collaborator


switch to fused_qk_rmsnorm?

Contributor Author


fused_qk_rmsnorm doesn't seem to support gemma yet.

"extraArgs": "-tp 8",
"env_vars": "",
"extraArgs": "-tp 8 --kv_cache_dtype fp8",
"env_vars": "ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1",
Collaborator


enable this one by default?

Contributor Author


Done

Signed-off-by: ganyi <ygan@amd.com>
Comment thread atom/models/qwen3_next.py
# Output layout is [q|k|v|z|b|a] contiguous (deinterleaved at load)
mixed_qkv, z_flat, b, a = torch.split(
    projected, [qkv_size, z_size, b_size, a_size], dim=-1
)
Contributor


⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name num_tokens

Comment thread atom/models/qwen3_next.py
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings May 8, 2026 07:31
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.

Comment thread atom/utils/envs.py
Comment on lines 37 to 38
# QK-norm-rope-cache-quant fusion for Qwen3-MoE; disabled by default.
# Enable for Qwen3-MoE to get better performance.
Comment on lines +243 to +246
# use_ps = self.adopt_persistent_kernel(
# head_size, num_kv_heads, num_q_heads_total
# )
use_ps = True
v_feat_idx = idx_feats - v_start_dim
is_q_block = idx_feats < k_start_dim
is_k_block = (idx_feats >= k_start_dim) & (idx_feats < v_start_dim)
is_v_block = idx_feats >= v_start_dim