[PyTorch] Enable head dim 256 for FA4 #2932
base: main
Changes from all commits: 3d0fcd7, 8aa5242, ad00e76, 472c9dd, 3090b57, 7e9faf1, 8fafa1f
```diff
@@ -7,7 +7,7 @@
 """
 import math
 import os
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
 import logging
 import functools
```
```diff
@@ -147,8 +147,11 @@ class FlashAttentionUtils:
     fa4_version = PkgVersion("0")
     use_v4 = False
     v4_installation_steps = """\
-pip install flash-attn-4==4.0.0b8 nvidia-cutlass-dsl[cu13]"""
+pip install flash-attn-4==4.0.0b11 nvidia-cutlass-dsl[cu13]"""
     v4_warning_printed = False
+    # Set by backends.py if FA4 is installed; calls flash_attn.cute.interface._validate_head_dims
+    # which raises AssertionError for unsupported (head_dim, head_dim_v) combinations.
+    v4_validate_head_dims: Callable = None
 
     @staticmethod
     def set_flash_attention_version():
```
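A minimal sketch of how this hook could be populated by backends.py. The attribute names `v4_is_installed` / `v4_validate_head_dims` and the `flash_attn.cute.interface._validate_head_dims` path come from the diff; the try/except wiring itself is illustrative, not the PR's actual backends.py code:

```python
# Illustrative wiring only -- the real logic lives in backends.py per the comment above.
try:
    from flash_attn.cute.interface import _validate_head_dims  # FA4's internal validator

    FlashAttentionUtils.v4_is_installed = True
    FlashAttentionUtils.v4_validate_head_dims = _validate_head_dims
except ImportError:
    FlashAttentionUtils.v4_is_installed = False
    FlashAttentionUtils.v4_validate_head_dims = None
```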
```diff
@@ -792,21 +795,25 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         )
         use_flash_attention_3 = False
 
-    if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
-        # FA4 head dimension support is architecture-dependent
-        # (matches _validate_head_dims in flash_attn.cute.interface):
-        # SM90: head_dim <= 256 and head_dim_v <= 256
-        # SM100/110: head_dim <= 128 and head_dim_v <= 128,
-        # OR DeepSeek MLA shape (head_dim=192, head_dim_v=128)
-        # SM80/120: constrained by shared memory (~256 max in practice)
-        _fa4_hdim_ok = True
-        if (10, 0) <= device_compute_capability < (12, 0):
-            _is_standard = head_dim_qk <= 128 and head_dim_v <= 128
-            _is_deepseek = head_dim_qk == 192 and head_dim_v == 128
-            _fa4_hdim_ok = _is_standard or _is_deepseek
-        else:
-            _fa4_hdim_ok = head_dim_qk <= 256 and head_dim_v <= 256
-        if not _fa4_hdim_ok:
+    if (
+        use_flash_attention_4
+        and FlashAttentionUtils.v4_is_installed
+        and FlashAttentionUtils.v4_validate_head_dims is not None
+    ):
+        # Defer to FA4's own _validate_head_dims to keep TE in sync with FA4 supported shapes
+        # (e.g., (256, 256) on SM100, (192, 128) DeepSeek, (64, 512) MLA-absorbed).
+        # The function asserts on unsupported combinations; SM80/SM120 have no validation branch
+        # in FA4 so the call passes through silently for those archs.
+        _fa4_alignment = 16 // torch.empty(0, dtype=qkv_dtype).element_size()
+        try:
+            # pylint: disable-next=not-callable
+            FlashAttentionUtils.v4_validate_head_dims(
+                head_dim_qk,
+                head_dim_v,
+                device_compute_capability[0],
+                _fa4_alignment,
+            )
+        except AssertionError:
             logger.debug(
                 "Disabling FlashAttention 4 due to unsupported head dimensions. "
                 "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
```
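The `_fa4_alignment` value above is just the number of elements of the QKV dtype that fit in 16 bytes. A quick standalone illustration (not part of the PR):

```python
import torch

# 16-byte alignment expressed in elements of the QKV dtype:
# 2-byte fp16/bf16 -> 8 elements, 4-byte fp32 -> 4 (and a 1-byte fp8 dtype would give 16).
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    alignment = 16 // torch.empty(0, dtype=dtype).element_size()
    print(dtype, alignment)
```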
```diff
@@ -815,13 +822,33 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
                 device_compute_capability[0] * 10 + device_compute_capability[1],
             )
             use_flash_attention_4 = False
-    # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128).
-    # FlashAttentionBackwardSm100 computes dK_reduce_ncol = gcd(32, tile_hdim // 2)
-    # based on Q/K head_dim but reuses it for dV TMEM load atoms. When
-    # (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are misaligned.
-    # See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
-    elif (
-        _fa4_hdim_ok
+    # flash-attn-4 4.0.0b11 validates (256, 256) on SM100, but its dedicated
+    # hd256 kernel diverges from the reference for cross-attention/decode-like
+    # shapes such as sq=1, skv=2048. Keep FA4 enabled for the self-attention
+    # hd256 path covered by the dedicated test, and fall back for cross-attn.
+    if (
+        use_flash_attention_4
+        and (10, 0) <= device_compute_capability < (12, 0)
+        and head_dim_qk == head_dim_v == 256
+        and max_seqlen_q != max_seqlen_kv
+    ):
+        logger.debug(
+            "Disabling FlashAttention 4 for SM100 head_dim=256 cross-attention. "
+            "Found: max_seqlen_q = %s, max_seqlen_kv = %s.",
+            max_seqlen_q,
+            max_seqlen_kv,
+        )
+        use_flash_attention_4 = False
+    # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128) for
+    # the standard (non-dedicated) kernel path. FlashAttentionBackwardSm100 computes
+    # dK_reduce_ncol = gcd(32, tile_hdim // 2) based on Q/K head_dim but reuses it for
+    # dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
+    # misaligned (e.g. dqk=128, dv=96 gives 48 % 32 != 0). The dedicated (256, 256)
+    # kernel uses its own tmem layout and is not affected.
+    # See: flash_attn/cute/flash_bwd_sm100.py ~L262 and ~L3890. Still present in
+    # flash-attn-4 4.0.0b11.
+    if (
```
**Collaborator:**

> Should this still be checked when …

**Member (Author):**

> I double checked that this is a bug of FA4. Kernels produce wrong results on these shapes but they're allowed by …
```diff
+        use_flash_attention_4
         and is_training
         and head_dim_qk != head_dim_v
         and head_dim_qk >= 128
```
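A standalone worked example of the misalignment condition described in the workaround comment above. The helper name and the values fed to it are illustrative only; the real check lives in FA4's `flash_attn/cute/flash_bwd_sm100.py`:

```python
from math import gcd

def dv_tmem_misaligned(tile_hdim: int, tile_hdimv: int) -> bool:
    # dK_reduce_ncol is derived from the Q/K head dim only ...
    dK_reduce_ncol = gcd(32, tile_hdim // 2)
    # ... but is reused for the dV TMEM load atoms, which span tile_hdimv // 2 columns.
    return (tile_hdimv // 2) % dK_reduce_ncol != 0

print(dv_tmem_misaligned(128, 96))   # True: gcd(32, 64) = 32 and 48 % 32 != 0 -> hit by the workaround
print(dv_tmem_misaligned(128, 128))  # False: 64 % 32 == 0 -> standard path unaffected
```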