
[Do Not Merge] For review purpose: Rocm/aiter mla dsv4 decode cudagraph #900

Draft

tjtanaavllm wants to merge 16 commits into ROCm:hexwang/dsv4_adapt_upstream from
ChuanLi1101:rocm/aiter-mla-dsv4-decode-cudagraph

Conversation


@tjtanaavllm tjtanaavllm commented Apr 26, 2026

Purpose

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

zyongye and others added 16 commits April 24, 2026 02:58
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: ganyi <ygan@amd.com>
Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.

Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
  - AiterSparseScratch: lazy-init persistent-mode metadata buffers,
    keyed by (batch, nhead, topk, dtype) so all 61 layers share one
    allocation per decode step (the lazy-init pattern is sketched
    after this list)
  - aiter_sparse_attn_decode: drop-in replacement handling dual-scope
    attention (SWA + extra), LSE-based merging (see the merge sketch
    below), and attn_sink correction
  - Uses the FP8/FP8 path only (gfx950 persistent mode + return_lse
    requires FP8)
  - Fixed-stride kv_indices layout with -1 sentinels (required by
    AITER persistent-mode kernels)

- deepseek_v4_attention.py:
  - Add _aiter_scratch / _aiter_extra_scratch fields to __init__
  - Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
    routes to _forward_decode_aiter, otherwise falls back to the
    existing PyTorch reference

- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py
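
A minimal sketch of the lazy-init scratch pattern above, assuming
hypothetical names (`_SparseScratchSketch`, `maybe_rebuild`,
`write_indices`); only the shape keying and the fixed-stride,
-1-padded layout mirror the real module:

```python
import torch

class _SparseScratchSketch:
    """Hypothetical stand-in for AiterSparseScratch: buffers are only
    (re)allocated when the shape key changes, so all 61 decode layers
    in a step reuse one allocation."""

    def __init__(self) -> None:
        self._key = None
        self.kv_indices = None  # fixed-stride (batch, topk), -1 padded

    def maybe_rebuild(self, batch, nhead, topk, dtype, device):
        key = (batch, nhead, topk, dtype)
        if key == self._key:
            return  # cache hit: data pointers stay stable across layers
        self._key = key
        # Fixed-stride layout: every row is exactly `topk` wide; unused
        # slots hold the -1 sentinel the persistent-mode kernels expect.
        self.kv_indices = torch.full(
            (batch, topk), -1, dtype=torch.int32, device=device
        )

    def write_indices(self, per_batch_indices):
        # per_batch_indices: one 1-D index tensor per request; shorter
        # rows keep their trailing -1 sentinels. All writes are in-place.
        self.kv_indices.fill_(-1)
        for row, idx in enumerate(per_batch_indices):
            self.kv_indices[row, : idx.numel()].copy_(idx)
```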

Validated numerically (cosine similarity > 0.999) across TP2/TP4/TP8
configs on MI355X. Micro-benchmarked at a 2.4x speedup (b=128,
dual-scope).
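
The LSE-based merge mentioned in the list above is the standard
two-scope softmax reduction. A minimal sketch, assuming each scope
returns its output plus a per-head log-sum-exp; shapes are
illustrative and the attn_sink correction is omitted:

```python
import torch

def merge_two_scopes(o_a, lse_a, o_b, lse_b):
    """Combine partial attention outputs over disjoint KV scopes.
    o_*: (tokens, heads, d_v); lse_*: (tokens, heads)."""
    lse_max = torch.maximum(lse_a, lse_b)
    # Reweight each scope by its true softmax mass, numerically stably.
    w_a = torch.exp(lse_a - lse_max)
    w_b = torch.exp(lse_b - lse_max)
    denom = (w_a + w_b).unsqueeze(-1)
    merged = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / denom
    merged_lse = lse_max + torch.log(w_a + w_b)
    return merged, merged_lse
```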

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Address review feedback (tjtanaa, vllm-project#40889): on ROCm,
DeepSeek sparse attention can only run through AITER, so gating the
op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode
  to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.

Changes
-------

* `AiterSparseScratch` now caches:
  * Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
    `col_arange`, `q_scale`, `kv_scale`.
  * Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
    `valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
  once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
  `aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
  tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
  `torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
  `torch.cumsum(out=)` instead of fresh allocations (sketched below).
* The public `aiter_sparse_attn_decode` signature is unchanged, so
  `DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.
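
A minimal sketch of that in-place pattern; buffer names follow the
list above, with `valid_mask` assumed to be a cached bool buffer of
shape (batch, topk) and `kv_indptr` a cached int32 buffer of length
batch + 1:

```python
import torch

def refresh_step_buffers(scratch, kv_lens, new_indices):
    """Rewrite cached per-step buffers without allocating, so
    data_ptr() values never change between graph replays.
    kv_lens: (batch,) int32; new_indices: (batch, topk) int32."""
    col = scratch.col_arange  # cached arange(topk)
    # valid_mask[b, j] = j < kv_lens[b], written in place
    torch.lt(col.unsqueeze(0), kv_lens.unsqueeze(1), out=scratch.valid_mask)
    scratch.kv_indices_2d.copy_(new_indices)
    # Out-of-range slots get the -1 sentinel the ASM kernel expects.
    # (`~` makes a small temporary; the real code may cache the
    # inverted mask as well.)
    scratch.kv_indices_2d.masked_fill_(~scratch.valid_mask, -1)
    # kv_indptr = [0, cumsum(kv_lens)], rebuilt without reallocation.
    scratch.kv_indptr[0].zero_()
    torch.cumsum(kv_lens, dim=0, dtype=torch.int32,
                 out=scratch.kv_indptr[1:])
```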

Follow-up to PR vllm-project#40889; the remaining cudagraph blocker is
the per-step `blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV
cache, which is tracked separately, along with the dequantise-into-FP8
fast path suggested in code review.

Test plan
---------

* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
  torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
  cudagraph-mode (non-eager) startup log + decode parity vs eager.

AI assistance: drafted with Cursor agent; human-reviewed before
submission.

Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.

Fix the cudagraph-clean refactor so the metadata is rewritten in-place
on every per-step call against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.

  * `AiterSparseScratch.rebuild()` now only allocates buffers and stores
    the static gqa/topk/dtype parameters; it no longer requires a
    `kv_indptr_seed` and no longer runs the metadata builder itself.
  * New `AiterSparseScratch.refresh_metadata()` reruns
    `get_mla_metadata_v1`, writing into the same `work_*`/`reduce_*`
    slots (the rebuild/refresh split is sketched after this list).
  * `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
    `kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
    step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.
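
A minimal sketch of that allocate-once / re-plan-in-place split; the
class and `plan_fn` are hypothetical stand-ins, and the real argument
lists of `get_mla_metadata_info_v1` / `get_mla_metadata_v1` are not
reproduced here:

```python
import torch

class _ScratchMetadataSketch:
    """Hypothetical slice of AiterSparseScratch: rebuild() sizes the
    plan buffers once; refresh_metadata() re-plans every step in place."""

    def rebuild(self, work_bytes, reduce_bytes, device="cuda"):
        # Sizes depend only on shapes + max_split_per_batch, so one
        # allocation covers any kv-length distribution.
        self.work_meta = torch.empty(
            work_bytes, dtype=torch.uint8, device=device)
        self.reduce_meta = torch.empty(
            reduce_bytes, dtype=torch.uint8, device=device)

    def refresh_metadata(self, plan_fn, kv_indptr):
        # plan_fn stands in for aiter.get_mla_metadata_v1: it must write
        # into the existing work_*/reduce_* storage rather than allocate,
        # so data_ptr() values stay stable for HIP-graph capture.
        plan_fn(kv_indptr, self.work_meta, self.reduce_meta)
```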

Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X (its pointer-stability check is sketched after this
list):
  - Call 1 (lens=[3,2]): success, scratch key set.
  - Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
    bit-identical to call 1.
  - Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
    expected (max abs diff = 2.39 vs identical-input call), no fault.
  - Parity check vs the original non-cudagraph implementation:
    max abs diff = 0.000000.
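
The pointer-stability check at the heart of that harness fits in a few
lines; a sketch with `scratch` and `decode_step` as hypothetical
stand-ins for the real harness objects:

```python
import torch

def assert_stable_ptrs(scratch, decode_step, lens_a, lens_b):
    """Run two decode steps with different kv lengths and assert that
    no scratch buffer was reallocated in between."""
    decode_step(scratch, lens_a)
    ptrs = {name: buf.data_ptr()
            for name, buf in vars(scratch).items()
            if torch.is_tensor(buf)}
    decode_step(scratch, lens_b)
    for name, ptr in ptrs.items():
        assert getattr(scratch, name).data_ptr() == ptr, (
            f"{name} was reallocated between steps")
```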

Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>
@whx-sjtu whx-sjtu force-pushed the hexwang/dsv4_adapt_upstream branch 5 times, most recently from caa50dd to d040af0 on May 1, 2026 03:30
@ganyi1996ppo ganyi1996ppo force-pushed the hexwang/dsv4_adapt_upstream branch from e786a2d to 72346ca on May 2, 2026 13:51
@whx-sjtu whx-sjtu force-pushed the hexwang/dsv4_adapt_upstream branch 5 times, most recently from c9b9e2a to ab79b0a on May 5, 2026 08:43