Enable FLASH_ATTN backend with upstream flash-attn CK on ROCm for decode #866
Draft
mgehre-amd wants to merge 6 commits into ROCm:gfx11 from
Conversation
git fetch upstream --tags and git describe can fail if the upstream repo is unreachable or no tags are reachable from HEAD. Use || to avoid aborting the workflow.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Trigger on push to gfx11 instead of main/matthias.awq_gemv. Remove the create-release and publish-to-gh-pages jobs; the wheel is still available as a GitHub Actions artifact.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
- Trigger workflow on PRs targeting gfx11 (build-only)
- On push to gfx11, upload wheel to S3 via OIDC + boto3
- S3 upload gated on ROCm org to skip on forks

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
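A minimal sketch of what a gated upload step like the one described in this commit could look like. The bucket name, key layout, and wheel path below are illustrative assumptions, not the workflow's actual values; credentials are assumed to come from the role assumed via OIDC earlier in the job.

```python
# Sketch of an org-gated wheel upload; bucket, key prefix, and paths are assumed.
import glob
import os
import sys

import boto3  # credentials resolved from the environment set up by the OIDC step

# Skip the upload on forks: only run when the repository lives in the ROCm org.
if os.environ.get("GITHUB_REPOSITORY_OWNER") != "ROCm":
    sys.exit(0)

s3 = boto3.client("s3")
for wheel in glob.glob("dist/*.whl"):
    key = f"wheels/gfx11/{os.path.basename(wheel)}"   # assumed key layout
    s3.upload_file(wheel, "example-vllm-wheels", key)  # assumed bucket name
```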
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
The FLASH_ATTN backend in vLLM V1 was tightly coupled to vllm_flash_attn (CUDA-only). On ROCm, fa_utils.py imported the upstream flash_attn_varlen_func, but the forward path passed vLLM-specific kwargs (out, fa_version, scheduler_metadata, etc.) that the upstream API doesn't accept, and get_flash_attn_version() returned None, causing an assertion failure.

Changes:
- Replace the raw upstream import with a wrapper in fa_utils.py that translates vLLM's calling convention to the upstream _wrapped_flash_attn_varlen_forward API, handling the seqused_k -> cu_seqlens_k conversion for the paged KV cache
- Return FA version 2 on ROCm when upstream flash-attn is available
- Set block_size to MultipleOf(128) on ROCm (CK kernel requirement)

The Triton AMD backend (FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE) does NOT support paged attention, so FLASH_ATTN as the vLLM backend requires the CK backend (FLASH_ATTENTION_TRITON_AMD_ENABLE unset).

Validated on gfx1151 with Qwen2.5-1.5B-Instruct: correct text generation with paged KV cache, concurrent requests, and multi-token decode.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
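A minimal sketch of the kind of translation such a wrapper performs, assuming a recent upstream flash-attn whose flash_attn_varlen_func accepts block_table. The PR itself targets _wrapped_flash_attn_varlen_forward, and the parameter handling shown here, including the seqused_k -> cu_seqlens_k conversion and the set of ignored kwargs, is illustrative rather than the PR's exact code.

```python
# Illustrative ROCm wrapper adapting vLLM's calling convention to upstream flash-attn.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func as _upstream_varlen_func


def flash_attn_varlen_func(
    q, k, v,
    max_seqlen_q, max_seqlen_k,
    cu_seqlens_q=None,
    seqused_k=None,           # per-sequence K lengths used with the paged cache
    block_table=None,         # maps logical blocks to physical KV-cache blocks
    softmax_scale=None,
    causal=False,
    out=None,                 # vLLM-specific: preallocated output buffer
    fa_version=None,          # vLLM-specific: ignored, upstream picks the kernel
    scheduler_metadata=None,  # vLLM-specific: FA3-only, ignored here
    **_ignored,               # swallow any other vLLM-only kwargs
):
    # Upstream expects cumulative K lengths; derive them from seqused_k by
    # prepending a zero to the running sum.
    cu_seqlens_k = F.pad(torch.cumsum(seqused_k, dim=0, dtype=torch.int32), (1, 0))

    result = _upstream_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        block_table=block_table,
    )
    if out is not None:
        out.copy_(result)
        return out
    return result
```

Swallowing unknown kwargs keeps vLLM's call sites unchanged while the upstream API evolves; only the arguments the CK path actually understands are forwarded.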
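A hypothetical way to exercise the validated configuration end to end. The Hugging Face model id, prompts, and sampling parameters are assumptions; only the model family and the backend/environment choices come from the commit message above.

```python
# Reproduction sketch: force vLLM's FLASH_ATTN backend on ROCm with CK kernels.
import os

# Leave FLASH_ATTENTION_TRITON_AMD_ENABLE unset so upstream flash-attn uses the
# CK backend, and select the FLASH_ATTN attention backend explicitly.
os.environ.pop("FLASH_ATTENTION_TRITON_AMD_ENABLE", None)
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # assumed HF id for the tested model
params = SamplingParams(temperature=0.0, max_tokens=64)

# Several identical prompts to exercise concurrent requests and the paged KV cache.
outputs = llm.generate(["What is the capital of France?"] * 4, params)
for o in outputs:
    print(o.outputs[0].text)
```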