
Enable FLASH_ATTN backend with upstream flash-attn CK on ROCm for decode #866

Draft
mgehre-amd wants to merge 6 commits into ROCm:gfx11 from mgehre-amd:matthias.flash-attn-ck-backend

Conversation


@mgehre-amd mgehre-amd commented Apr 10, 2026

The FLASH_ATTN backend in vLLM V1 was tightly coupled to vllm_flash_attn (CUDA-only). On ROCm, fa_utils.py imported the upstream flash_attn_varlen_func, but the forward path passed vLLM-specific kwargs (out, fa_version, scheduler_metadata, etc.) that the upstream API doesn't accept, and get_flash_attn_version() returned None, causing an assertion failure.

Changes:

  • Replace the raw upstream import with a wrapper in fa_utils.py that translates vLLM's calling convention to the upstream _wrapped_flash_attn_varlen_forward API, handling the seqused_k -> cu_seqlens_k conversion for the paged KV cache (see the sketch after this list)
  • Return FA version 2 on ROCm when upstream flash-attn is available
  • Set block_size to MultipleOf(128) on ROCm (CK kernel requirement)
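
For illustration only, here is a minimal sketch of what such a wrapper could look like. It calls the upstream public flash_attn_varlen_func rather than the internal _wrapped_flash_attn_varlen_forward the PR targets, and the exact set of vLLM kwargs shown (out, fa_version, scheduler_metadata) is an assumption based on the description above; it also assumes a flash-attn version whose varlen API accepts block_table:

```python
import torch
from flash_attn import flash_attn_varlen_func as _upstream_varlen_func


def flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=None,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_k=None,           # per-sequence used KV length (paged KV cache)
    block_table=None,
    softmax_scale=None,
    causal=False,
    out=None,                 # vLLM-specific: preallocated output buffer
    fa_version=None,          # vLLM-specific: not understood upstream
    scheduler_metadata=None,  # vLLM-specific: not understood upstream
    **_ignored,               # drop any other vLLM-only kwargs
):
    # Upstream has no seqused_k argument; derive cumulative key lengths
    # (leading zero followed by a running sum) from the per-sequence lengths.
    cu_seqlens_k = None
    if seqused_k is not None:
        cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(seqused_k, dim=0, dtype=torch.int32), (1, 0)
        )

    result = _upstream_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        block_table=block_table,
    )
    # vLLM passes a caller-owned output buffer; copy into it when given.
    if out is not None:
        out.copy_(result)
        return out
    return result
```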

The Triton AMD backend (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE) does NOT support paged attention, so using FLASH_ATTN as the vLLM attention backend requires the CK backend (FLASH_ATTENTION_TRITON_AMD_ENABLE unset).

Validated on gfx1151 with Qwen2.5-1.5B-Instruct: correct text generation with paged KV cache, concurrent requests, and multi-token decode.

  • TODO: detect whether the installed flash-attn contains the CK backend, and whether the CK backend is enabled (FLASH_ATTENTION_TRITON_AMD_ENABLE unset); a possible check is sketched below.
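
One possible shape for that detection, sketched under the assumption that the CK build ships a compiled extension importable as flash_attn_2_cuda (the module name may differ between builds); the environment-variable check mirrors the constraint described above:

```python
import importlib.util
import os


def upstream_fa_ck_usable() -> bool:
    """Hypothetical helper for the TODO above: True if upstream flash-attn is
    installed with a usable CK backend and the Triton AMD backend is not
    forced via the environment."""
    if importlib.util.find_spec("flash_attn") is None:
        return False
    # The Triton AMD path does not support paged attention, so it must not
    # be selected for the FLASH_ATTN backend.
    if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE":
        return False
    # Assumption: the CK build exposes a compiled extension module; a
    # Triton-only install would not provide one.
    return importlib.util.find_spec("flash_attn_2_cuda") is not None
```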

git fetch upstream --tags and git describe can fail if the
upstream repo is unreachable or no tags are reachable from HEAD.
Use || to avoid aborting the workflow.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Trigger on push to gfx11 instead of main/matthias.awq_gemv.
Remove create-release and publish-to-gh-pages jobs.
Wheel is available as a GitHub Actions artifact.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
- Trigger workflow on PRs targeting gfx11 (build-only)
- On push to gfx11, upload wheel to S3 via OIDC + boto3 (a minimal upload sketch follows below)
- S3 upload gated on ROCm org to skip on forks
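
For reference, a minimal sketch of what the boto3 upload step could look like; the bucket name, prefix, environment variable, and wheel location are placeholders rather than the workflow's actual values, and credentials are assumed to come from the OIDC role already assumed in the job:

```python
import glob
import os

import boto3  # credentials are picked up from the OIDC role in the job environment

# Placeholder bucket/prefix; the real values live in the workflow configuration.
bucket = os.environ.get("WHEEL_S3_BUCKET", "example-vllm-wheels")
prefix = "gfx11"

s3 = boto3.client("s3")
for wheel in glob.glob("dist/*.whl"):
    key = f"{prefix}/{os.path.basename(wheel)}"
    s3.upload_file(wheel, bucket, key)
    print(f"uploaded s3://{bucket}/{key}")
```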

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
mgehre-amd changed the title from "Enable FLASH_ATTN backend with upstream flash-attn CK on ROCm" to "Enable FLASH_ATTN backend with upstream flash-attn CK on ROCm for decode" on May 4, 2026