diff --git a/atom/plugin/vllm/attention_backend/attention_gdn.py b/atom/plugin/vllm/attention_backend/attention_gdn.py index 619964abf..bc1230665 100644 --- a/atom/plugin/vllm/attention_backend/attention_gdn.py +++ b/atom/plugin/vllm/attention_backend/attention_gdn.py @@ -358,12 +358,16 @@ def forward( # 2.2: Process the remaining part if attn_metadata.num_prefills > 0: + from aiter.ops.triton.gated_delta_net.gated_delta_rule import ( + chunk_gated_delta_rule_opt_vk, + ) + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 ( core_attn_out_non_spec, last_recurrent_state, - ) = self.chunk_gated_delta_rule( + ) = chunk_gated_delta_rule_opt_vk( q=query_non_spec, k=key_non_spec, v=value_non_spec,