perf: optimize GDN decode with SGLang fused recurrent kernel #727
Conversation
Use SGLang's `fused_sigmoid_gating_delta_rule_update` for pure-decode batches. It fuses the gating computation (sigmoid, softplus, dt_bias) with the delta-rule state update into a single kernel call, avoiding materialization of the intermediate `g`/`beta` tensors. Also make the `conv_state` transpose conditional: SGLang provides the `[slot, conv_dim, state_len]` layout directly, while ModelRunner uses `[slot, state_len, conv_dim]`.

Benchmark results:

Qwen3.5-27B bf16 (MI308X, TP2, ISL=60381, OSL=132):

| Version | TPOT (ms) | TTFT (ms) | Tput/GPU |
|-------------------|-----------|-----------|----------|
| before this patch | 15.0 | 12613 | 2078 |
| + this patch | 14.5 | 12628 | 2079 |

Qwen3.5-35B-A3B-FP8 (MI308X, TP1, ISL=4096, OSL=2048, concurrency=384):

| Version | TPOT (ms) | TTFT (ms) | Tput/GPU |
|-------------------|-----------|-----------|----------|
| before this patch | 133.3 | 32559 | 7722 |
| + this patch | 129.4 | 32400 | 7937 |

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
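For context, a minimal sketch of the gating math that the fused kernel folds into the recurrent update. The tensor and parameter names (`a`, `b`, `A_log`, `dt_bias`) are assumptions based on common Gated DeltaNet implementations, not taken from this diff:

```python
import torch
import torch.nn.functional as F


def unfused_gdn_gating(a, b, A_log, dt_bias):
    """Reference (unfused) gating: materializes g and beta as full tensors
    before the delta-rule state update. The fused kernel computes the same
    quantities inside the state-update kernel instead."""
    # decay gate from softplus with a learned dt_bias (names assumed)
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias)
    # delta-rule mixing coefficient
    beta = b.sigmoid()
    return g, beta
```

With the fused path these intermediates are computed per element inside the kernel rather than written to memory first, which is the saving the description above points to.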
Pull request overview
This PR optimizes Gated Delta Net (GDN) attention for decode-only batches by leveraging SGLang’s fused sigmoid-gating + delta-rule update kernel to reduce intermediate tensor materialization and kernel launches, while also improving compatibility across different attention-metadata/cache layouts.
Changes:
- Adds an optional SGLang fused decode path (`fused_sigmoid_gating_delta_rule_update`) and skips materializing `g`/`beta` when that path is used.
- Makes the `conv_state` transposition conditional to support differing cache layouts (ModelRunner vs SGLang).
- Accesses `gdn_metadata` via `getattr` to avoid errors when non-GDN attention metadata is used (see the sketch after this list).
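A minimal sketch of the defensive metadata access described in the last bullet; the attribute name `gdn_metadata` is from the PR, while the function name and surrounding types are placeholders:

```python
from typing import Any, Optional


def get_gdn_metadata(attn_metadata: Any) -> Optional[Any]:
    # Returns None for non-GDN attention metadata instead of raising
    # AttributeError when the attribute is absent.
    return getattr(attn_metadata, "gdn_metadata", None)
```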
```python
sglang_fused_sigmoid_gating_delta_rule_update = importlib.import_module(
    "sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent"
).fused_sigmoid_gating_delta_rule_update
```
```python
# transpose below. SGLang already provides [slot, conv_dim, state_len],
# and the Triton kernel consumes the original conv_state strides directly.
if conv_state.size(1) != self.conv1d.weight.size(0):
    # transpose for ModelRunner
    conv_state = conv_state.transpose(-1, -2)
```
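To make the layout check concrete, a toy illustration of the two cache layouts; the sizes are made-up values, only the dimension order matches the comment above:

```python
import torch

conv_dim, state_len, num_slots = 128, 4, 8  # toy sizes, not from the PR

sglang_state = torch.zeros(num_slots, conv_dim, state_len)       # [slot, conv_dim, state_len]
modelrunner_state = torch.zeros(num_slots, state_len, conv_dim)  # [slot, state_len, conv_dim]

# The Triton kernel expects conv_dim on dim 1 (i.e. conv_state.size(1) ==
# conv1d.weight.size(0)), so only the ModelRunner layout needs the transpose.
assert sglang_state.size(1) == conv_dim
assert modelrunner_state.transpose(-1, -2).size(1) == conv_dim
```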
The conv_state layout fix here has already been merged in PR #532; please rebase onto main to pick up those changes.
```python
try:
    sglang_fused_sigmoid_gating_delta_rule_update = importlib.import_module(
        "sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent"
```
Given that this optimization is specific to the SGLang framework, would it be better, for code maintainability, to implement it via inheritance in atom/plugin/sglang/attention_backend/attention_gdn.py?
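For illustration, a rough sketch of what that plugin-side subclass could look like; every name here (base class, method) is hypothetical, since the actual class hierarchy is not shown in this thread:

```python
# atom/plugin/sglang/attention_backend/attention_gdn.py (hypothetical sketch)


class GDNAttention:  # stand-in for the existing framework-agnostic op (name assumed)
    def forward_decode(self, *args, **kwargs):
        raise NotImplementedError


class SGLangGDNAttention(GDNAttention):
    """Keeps the SGLang-specific fused decode path out of the generic op."""

    def forward_decode(self, *args, **kwargs):
        # Call the fused sigmoid-gating delta-rule kernel here, leaving the
        # base class free of SGLang imports.
        ...
```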
Summary
- Use SGLang's `fused_sigmoid_gating_delta_rule_update` for pure-decode batches in GDN attention, fusing the gating computation (sigmoid, softplus, dt_bias) with the delta-rule state update into a single kernel call
- Make the `conv_state` transpose conditional: SGLang provides the `[slot, conv_dim, state_len]` layout directly, while ModelRunner uses `[slot, state_len, conv_dim]`
- Use `getattr` for `gdn_metadata` access to avoid AttributeError on non-GDN attention metadata

Benchmark
Qwen3.5-27B bf16 (MI308X, TP2, ISL=60381, OSL=132):

| Version | TPOT (ms) | TTFT (ms) | Tput/GPU |
|-------------------|-----------|-----------|----------|
| before this patch | 15.0 | 12613 | 2078 |
| + this patch | 14.5 | 12628 | 2079 |
Qwen3.5-35B-A3B-FP8 (MI308X, TP1, ISL=4096, OSL=2048, concurrency=384):

| Version | TPOT (ms) | TTFT (ms) | Tput/GPU |
|-------------------|-----------|-----------|----------|
| before this patch | 133.3 | 32559 | 7722 |
| + this patch | 129.4 | 32400 | 7937 |
Files Changed
- `atom/model_ops/attention_gdn.py` — fused decode path + conv_state fix

Test plan
🤖 Generated with Claude Code