[Do Not Merge] For review purpose: Rocm/aiter mla dsv4 decode cudagraph#900
Draft
tjtanaavllm wants to merge 16 commits into ROCm:hexwang/dsv4_adapt_upstream from
Conversation
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…oject#225) Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: ganyi <ygan@amd.com>
Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.
Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
- AiterSparseScratch: lazy-init persistent-mode metadata buffers,
keyed by (batch, nhead, topk, dtype) so 61 layers share one
allocation per decode step
- aiter_sparse_attn_decode: drop-in replacement handling dual-scope
  attention (SWA + extra), LSE-based merging (sketched after this list),
  and attn_sink correction
- Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
requires FP8)
- Fixed-stride kv_indices layout with -1 sentinels (required by
AITER persistent-mode kernels)
- deepseek_v4_attention.py:
- Add _aiter_scratch / _aiter_extra_scratch fields to __init__
- Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
routes to _forward_decode_aiter, otherwise falls back to the
existing PyTorch reference
- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py
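For reference, a minimal sketch of the LSE-based merge of the two attention scopes (variable names here are illustrative, not the module's actual identifiers):

```python
import torch

def merge_scopes_by_lse(o_a, lse_a, o_b, lse_b):
    """Combine two partial attention outputs via their log-sum-exp.

    o_*:   [batch, heads, d_v] partial output of one scope
    lse_*: [batch, heads]      log-sum-exp of that scope's logits
    """
    m = torch.maximum(lse_a, lse_b)            # stabilise before exp()
    w_a = torch.exp(lse_a - m).unsqueeze(-1)   # proportional to sum(exp(logits_a))
    w_b = torch.exp(lse_b - m).unsqueeze(-1)
    out = (o_a * w_a + o_b * w_b) / (w_a + w_b)
    lse = m + torch.log((w_a + w_b).squeeze(-1))  # merged LSE for later correction
    return out, lse
```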
Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Address review feedback (tjtanaa, vllm-project#40889): on ROCm, DeepSeek
sparse attention can only run through AITER, so gating the op with
VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.
- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode to
  dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.
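For context, a minimal illustration of why this matters: CUDA/HIP graphs record fixed device pointers, so everything touched during capture must live in preallocated storage that is only ever rewritten, never reallocated (a generic torch sketch, not code from this PR):

```python
import torch

# Allocated once, outside capture; the graph bakes in these pointers.
static_in = torch.zeros(128, device="cuda")
static_out = torch.empty(128, device="cuda")

# (Real code warms up on a side stream before capture; omitted for brevity.)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Every write inside capture targets pre-existing storage.
    static_out.copy_(static_in * 2.0)

# Per step: refresh inputs in place, then replay the recorded kernels.
static_in.copy_(torch.randn(128, device="cuda"))
g.replay()   # static_out now holds the result for the new input
```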
Changes
-------
* `AiterSparseScratch` now caches:
* Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
`col_arange`, `q_scale`, `kv_scale`.
* Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
`valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
`aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in place via
  `torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
  `torch.cumsum(out=)` instead of fresh allocations (sketched after this
  list).
* The public `aiter_sparse_attn_decode` signature is unchanged, so
`DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.
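A hedged sketch of that in-place rewrite pattern (buffer names follow the list above; the shapes, integer dtypes, and the free function standing in for the real method are simplifying assumptions):

```python
import torch

def rewrite_step_buffers(scratch, kv_indices, kv_lens):
    """Refresh per-step scratch in place; no new device allocations."""
    # valid_mask[i, j] = j < kv_lens[i], written into the cached buffer.
    torch.lt(scratch.col_arange, kv_lens.unsqueeze(1), out=scratch.valid_mask)
    # Copy this step's top-k indices over the fixed-stride layout...
    scratch.kv_indices_2d.copy_(kv_indices)
    # ...and stamp the -1 sentinels AITER expects over the padded tail.
    scratch.kv_indices_2d.masked_fill_(~scratch.valid_mask, -1)
    # Per-batch valid lengths, then the exclusive prefix sum for kv_indptr
    # (assumes valid_lens and kv_indptr share an integer dtype).
    scratch.valid_lens.copy_(scratch.valid_mask.sum(dim=1))
    scratch.kv_indptr[0] = 0
    torch.cumsum(scratch.valid_lens, dim=0, out=scratch.kv_indptr[1:])
```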
Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.
Test plan
---------
* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
cudagraph-mode (non-eager) startup log + decode parity vs eager.
AI assistance: drafted with Cursor agent; human-reviewed before
submission.
Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.
Fix the cudagraph-clean refactor so the metadata is rewritten in place
on every decode step against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.
* `AiterSparseScratch.rebuild()` now only allocates buffers and stores
the static gqa/topk/dtype parameters; it no longer requires a
`kv_indptr_seed` and no longer runs the metadata builder itself.
* New `AiterSparseScratch.refresh_metadata()` reruns
  `get_mla_metadata_v1`, writing into the same `work_*`/`reduce_*` slots
  (sketched after this list).
* `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
`kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.
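A rough sketch of that allocate-once / refresh-in-place split, with a stand-in for the AITER builder since its real signature is not reproduced here:

```python
import torch

def build_plan_into(work, reduce, kv_indptr):
    """Stand-in for aiter.get_mla_metadata_v1: recompute the split plan
    from the *current* per-batch kv lengths and write it into the
    preallocated work_*/reduce_* slots (contents are illustrative)."""
    lens = kv_indptr[1:] - kv_indptr[:-1]   # actual per-batch kv lengths
    work[: lens.numel()].copy_(lens)
    reduce.zero_()

class AiterSparseScratch:
    def __init__(self, batch, max_work, device="cuda"):
        # rebuild(): sizes depend on shapes + max_split_per_batch only,
        # so these are upper bounds for any kv-length distribution.
        self.kv_indptr = torch.zeros(batch + 1, dtype=torch.int32, device=device)
        self.work = torch.zeros(max_work, dtype=torch.int32, device=device)
        self.reduce = torch.zeros(max_work, dtype=torch.int32, device=device)

    def refresh_metadata(self):
        # Same storage every step -> stable data_ptrs for graph capture;
        # only the contents track the current kv lengths.
        build_plan_into(self.work, self.reduce, self.kv_indptr)
```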
Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X (the pointer-stability check is sketched after this
list):
- Call 1 (lens=[3,2]): success, scratch key set.
- Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
bit-identical to call 1.
- Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
expected (max abs diff = 2.39 vs identical-input call), no fault.
- Parity check vs the original non-cudagraph implementation:
max abs diff = 0.000000.
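The pointer-stability property those calls check can be asserted in a few lines (a generic sketch; the harness itself is not reproduced here):

```python
def snapshot_ptrs(scratch, names):
    """Record the device pointer of each cached buffer."""
    return {n: getattr(scratch, n).data_ptr() for n in names}

def assert_ptrs_stable(scratch, names, before):
    """Fail if any cached buffer was reallocated between decode calls."""
    for n in names:
        now = getattr(scratch, n).data_ptr()
        assert now == before[n], f"{n} reallocated: {before[n]:#x} -> {now:#x}"

# Usage: snapshot after call 1, assert after calls 2 and 3.
# before = snapshot_ptrs(scratch, ["kv_indptr", "kv_indices_2d", "q_fp8"])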
Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>