Changes from all commits (27 commits)
2b9fbc5  refactor nvte_get_fused_attn_backend with FE calls  (cyanguwa, May 6, 2026)
16b837c  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 6, 2026)
5a482f9  Merge branch 'main' into fe_check_support  (cyanguwa, May 6, 2026)
42bcd89  replace code+string with string only  (cyanguwa, May 7, 2026)
de8e814  clean up logic/comments/structure  (cyanguwa, May 8, 2026)
4b8c7ed  Merge branch 'main' into fe_check_support  (cyanguwa, May 8, 2026)
81e59a9  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 8, 2026)
5640c68  Merge branch 'main' into fe_check_support  (cyanguwa, May 8, 2026)
6c5126d  fix compilation errors  (cyanguwa, May 8, 2026)
d35bff7  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 8, 2026)
f6fc585  remove handle from API; add bottom_right_diagonal  (cyanguwa, May 8, 2026)
3e666b0  add batch_size to API  (cyanguwa, May 8, 2026)
056aba6  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 8, 2026)
e054863  fix jax binding  (cyanguwa, May 8, 2026)
a7fe928  specify o_dtype for FP8s  (cyanguwa, May 8, 2026)
c9b22b5  fix BRCM and custom_fp8 tests  (cyanguwa, May 8, 2026)
ac44e66  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 8, 2026)
9131b2d  add o_format/etc to API and other tweaks  (cyanguwa, May 8, 2026)
956f159  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 8, 2026)
b21f606  minor tweaks for docstring  (cyanguwa, May 8, 2026)
3421920  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 8, 2026)
7956b43  replace with nvte_get_fused_attn_backend_v2 and add NVTEFusedAttnConfig  (cyanguwa, May 8, 2026)
e86fc67  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 8, 2026)
e2561d0  fix FP8 tests  (cyanguwa, May 12, 2026)
724a12f  add do_dtype and dqkv_dtype to API  (cyanguwa, May 12, 2026)
3ae36df  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], May 12, 2026)
3532e98  Merge branch 'main' into fe_check_support  (cyanguwa, May 12, 2026)
3 changes: 3 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
@@ -75,6 +75,7 @@ def impl_test_self_attn(
 
     if not is_fused_attn_kernel_available(
         is_training,
+        batch,
         dtype,
         dtype,
         QKVLayout.BS3HD,
@@ -227,6 +228,7 @@ def test_cross_attn(
 
     if not is_fused_attn_kernel_available(
         is_training,
+        batch,
         dtype,
         dtype,
         QKVLayout.BSHD_BS2HD,
@@ -368,6 +370,7 @@ def impl_test_context_parallel_attn(
     def check_has_backend_for_mask(mask_type):
         return is_fused_attn_kernel_available(
             is_training,
+            batch,
             dtype,
             dtype,
             qkv_layout,
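All three hunks above thread the batch size into the availability query as the new second positional argument. A hypothetical helper sketching that gating pattern (the name skip_unless_fused_attn and the trailing *rest are illustrative; the real is_fused_attn_kernel_available takes more arguments than this diff shows):

import pytest

def skip_unless_fused_attn(is_available_fn, is_training, batch, dtype, qkv_layout, *rest):
    """Skip the calling test when no fused-attention kernel supports this configuration.

    is_available_fn stands in for TE JAX's is_fused_attn_kernel_available; the argument
    order (is_training, batch, dtype, dtype, qkv_layout, ...) mirrors the calls in this
    diff, with the remaining arguments elided behind *rest.
    """
    if not is_available_fn(is_training, batch, dtype, dtype, qkv_layout, *rest):
        pytest.skip("Fused attention kernel not available for this configuration")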
6 changes: 4 additions & 2 deletions tests/jax/test_fused_attn.py
@@ -444,8 +444,9 @@ def _check_configs(self):
                 "is either BSHD_BSHD_BSHD or THD_THD_THD"
             )
 
-        self.backend = FusedAttnHelper(
+        self.backend, message = FusedAttnHelper(
             self.is_training,
+            self.batch_size,
             self.dtype,
             self.dtype,
             self.qkv_layout,
@@ -460,9 +461,10 @@ def _check_configs(self):
             self.head_dim_qk,
             self.head_dim_v,
             (-1, -1) if self.window_size is None else self.window_size,
+            self.attn_mask_type.is_bottom_right(),
         ).get_fused_attn_backend()
         if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
-            pytest.skip("Unsupported inputs combination or device compute capability.")
+            pytest.skip(message)
 
         if (
             self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
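The change above replaces a generic skip string with the diagnostic message returned alongside the backend, so skipped tests report why a configuration is unsupported. A minimal sketch of that pattern with hypothetical names (require_backend and get_backend_fn are illustrative, not TE's API):

import pytest

def require_backend(get_backend_fn, expected_backend):
    """Skip the current test unless the query selects the expected fused-attention backend."""
    backend, message = get_backend_fn()  # the query now returns a (backend, reason) pair
    if backend != expected_backend:
        pytest.skip(message)             # surface the backend's own reason as the skip message
    return backend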
24 changes: 24 additions & 0 deletions tests/pytorch/attention/test_attention.py
@@ -1775,12 +1775,23 @@ def test_dpa_fp8_extra_state(model, dtype):
     config = model_configs_fp8_extra_state[model]
     # Test backend availability
     is_training = True
+    fp8_recipe = recipe.DelayedScaling(
+        margin=0,
+        fp8_format=recipe.Format.HYBRID,
+        amax_history_len=1,
+        amax_compute_algo="most_recent",
+        fp8_dpa=True,
+    )
+    fp8_meta = {}
+    fp8_meta["recipe"] = fp8_recipe
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout="sb3hd",
         is_training=is_training,
         deterministic=_deterministic,
+        fp8=True,
+        fp8_meta=fp8_meta,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported and not flash_attn_supported:
@@ -2567,13 +2578,25 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
     Both paths take F16 input and output. QKV layout is bs3hd"""
 
     config = model_configs_fp8[model]
+    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"
 
     # Test backend availability
     is_training = True
+    fp8_meta = {}
+    fp8_recipe = recipe.DelayedScaling(
+        margin=0,
+        fp8_format=recipe.Format.HYBRID,
+        amax_history_len=1,
+        amax_compute_algo="most_recent",
+        fp8_dpa=True,
+    )
+    fp8_meta["recipe"] = fp8_recipe
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout="bs3hd",
+        fp8=True,
+        fp8_meta=fp8_meta,
         is_training=is_training,
         deterministic=_deterministic,
     )
@@ -2651,6 +2674,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
         fp8_format=recipe.Format.HYBRID,
         amax_history_len=1,
         amax_compute_algo="most_recent",
+        fp8_dpa=True,
     )
 
     mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
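The FP8 tests above now build a DelayedScaling recipe with fp8_dpa=True and pass it to the backend-availability query through fp8_meta. A standalone sketch of that setup, assuming TransformerEngine's recipe module is importable (get_available_attention_backends itself is a helper local to these tests):

from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,   # E4M3 forward, E5M2 backward
    amax_history_len=1,
    amax_compute_algo="most_recent",
    fp8_dpa=True,                      # request FP8 dot-product attention
)
fp8_meta = {"recipe": fp8_recipe}
# The tests then call get_available_attention_backends(..., fp8=True, fp8_meta=fp8_meta).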
5 changes: 5 additions & 0 deletions tests/pytorch/utils.py
@@ -300,6 +300,10 @@ def __init__(
         self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
         self.bias_shape = bias_shape
         self.window_size = check_set_window_size(self.attn_mask_type, window_size)
+        self.bottom_right_diagonal = self.attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }
         self.context_parallel = context_parallel
         self.cp_comm_type = cp_comm_type
         self.return_max_logit = return_max_logit
@@ -376,6 +380,7 @@ def test():
             head_dim_v=config.head_dim_v,
             attn_mask_type=config.attn_mask_type,
             window_size=config.window_size,
+            bottom_right_diagonal=config.bottom_right_diagonal,
             alibi_slopes_shape=alibi_slopes_shape,
             core_attention_bias_type=config.attn_bias_type,
             core_attention_bias_shape=core_attention_bias_shape,
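The new bottom_right_diagonal flag is derived purely from the mask-type string, as the first hunk above shows. A standalone restatement of that mapping (the two mask-type strings come from this diff, not from an exhaustive list of TE mask types):

def uses_bottom_right_diagonal(attn_mask_type: str) -> bool:
    """Bottom-right-aligned causal masks anchor the causal diagonal at the bottom-right corner."""
    return attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}

assert uses_bottom_right_diagonal("padding_causal_bottom_right")
assert not uses_bottom_right_diagonal("causal")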