Skip to content

perf: optimize GDN decode with SGLang fused recurrent kernel#727

Open
zovonoir wants to merge 1 commit intoROCm:mainfrom
zovonoir:perf/gdn-decode-optimization
Open

perf: optimize GDN decode with SGLang fused recurrent kernel#727
zovonoir wants to merge 1 commit intoROCm:mainfrom
zovonoir:perf/gdn-decode-optimization

Conversation

@zovonoir
Copy link
Copy Markdown
Contributor

@zovonoir zovonoir commented May 9, 2026

Summary

  • Use SGLang's fused_sigmoid_gating_delta_rule_update for pure-decode batches in GDN attention, fusing gating computation (sigmoid, softplus, dt_bias) with the delta rule state update into a single kernel call
  • Skip materialization of intermediate g/beta tensors when the fused path is available
  • Fix conv_state transpose to be conditional — SGLang provides [slot, conv_dim, state_len] layout directly, while ModelRunner uses [slot, state_len, conv_dim]
  • Use getattr for gdn_metadata access to avoid AttributeError on non-GDN attention metadata

Benchmark

Qwen3.5-27B bf16 (MI308X, TP2, ISL=60381, OSL=132):

Version TPOT (ms) TTFT (ms) Tput/GPU (tok/s)
ATOM baseline 18.9 17562 1510
+ fused Triton kernels (#708) + tuned GEMM (ROCm/aiter#3077) 14.8 12613 2078
+ this patch 14.7 12628 2079

Qwen3.5-35B-A3B-FP8 (MI308X, TP1, ISL=4096, OSL=2048, concurrency=384):

Version TPOT (ms) TTFT (ms) Tput/GPU (tok/s)
before this patch 133.3 32559 7722
+ this patch 129.4 32400 7937

Files Changed

  • atom/model_ops/attention_gdn.py — fused decode path + conv_state fix

Test plan

  • End-to-end benchmark on MI308X with Qwen3.5-27B (bf16) and Qwen3.5-35B-A3B (FP8)
  • CI accuracy tests

🤖 Generated with Claude Code

Use SGLang's fused_sigmoid_gating_delta_rule_update for pure-decode
batches, which fuses gating computation (sigmoid, softplus, dt_bias)
with the delta rule state update into a single kernel call, avoiding
materialization of intermediate g/beta tensors.

Also fix conv_state transpose to be conditional — SGLang provides
[slot, conv_dim, state_len] layout directly, while ModelRunner uses
[slot, state_len, conv_dim].

Benchmark (MI308X, TP2, ISL=60381, OSL=132):
Qwen3.5-27B bf16:
| Version           | TPOT(ms) | TTFT(ms) | Tput/GPU |
|-------------------|----------|----------|----------|
| before this patch | 15.0     | 12613    | 2078     |
| + this patch      | 14.5     | 12628    | 2079     |

Qwen3.5-35B-A3B-FP8 (TP1, ISL=4096, OSL=2048, concurrency=384):
| Version           | TPOT(ms) | TTFT(ms) | Tput/GPU |
|-------------------|----------|----------|----------|
| before this patch | 133.3    | 32559    | 7722     |
| + this patch      | 129.4    | 32400    | 7937     |

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 9, 2026 04:53
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR optimizes Gated Delta Net (GDN) attention for decode-only batches by leveraging SGLang’s fused sigmoid-gating + delta-rule update kernel to reduce intermediate tensor materialization and kernel launches, while also improving compatibility across different attention-metadata/cache layouts.

Changes:

  • Adds an optional SGLang fused decode path (fused_sigmoid_gating_delta_rule_update) and skips materializing g/beta when that path is used.
  • Makes conv_state transposition conditional to support differing cache layouts (ModelRunner vs SGLang).
  • Accesses gdn_metadata via getattr to avoid errors when non-GDN attention metadata is used.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +25 to +27
sglang_fused_sigmoid_gating_delta_rule_update = importlib.import_module(
"sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent"
).fused_sigmoid_gating_delta_rule_update
Comment on lines +191 to +194
# transpose below. SGLang already provides [slot, conv_dim, state_len],
# and the Triton kernel consumes the original conv_state strides directly.
if conv_state.size(1) != self.conv1d.weight.size(0):
# transpose for ModelRunner
# and the Triton kernel consumes the original conv_state strides directly.
if conv_state.size(1) != self.conv1d.weight.size(0):
# transpose for ModelRunner
conv_state = conv_state.transpose(-1, -2)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the fix about conv state layout here have been merged in PR #532, please rebase main for the changes.


try:
sglang_fused_sigmoid_gating_delta_rule_update = importlib.import_module(
"sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this optimization is customized for the SGLang framework and for code maintainability, would it be better to implement it via inheritance in atom/plugin/sglang/attention_backend/attention_gdn.py?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants