From 2b9fbc5c3dc59e86187350ba875eee819e8f2bf0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 5 May 2026 18:50:17 -0700 Subject: [PATCH 01/23] refactor nvte_get_fused_attn_backend with FE calls Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 516 ++++++------------ .../fused_attn_f16_arbitrary_seqlen.cu | 136 +++++ .../fused_attn_f16_arbitrary_seqlen.h | 22 + .../common/fused_attn/fused_attn_fp8.cu | 101 ++++ .../common/fused_attn/fused_attn_fp8.h | 25 + .../include/transformer_engine/fused_attn.h | 50 +- .../jax/csrc/extensions/attention.cpp | 32 +- .../pytorch/csrc/extensions/attention.cpp | 17 +- 8 files changed, 539 insertions(+), 360 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 141767b803..615f7c2a03 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -226,357 +226,189 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } } +namespace { + +// Per-thread storage for the message string handed back through +// NVTEFusedAttnBackendStatus::message. Re-used (cleared + re-populated) on every call to +// nvte_get_fused_attn_backend on this thread, which is exactly the lifetime documented in the +// public header. +thread_local std::string g_fused_attn_backend_status_buffer; + +// Apply (code, msg) to *out_status (if non-null), routing the message through the +// thread-local buffer so the returned `const char*` outlives this function call. 
+void set_status(NVTEFusedAttnBackendStatus *out_status, cudnn_frontend::error_code_t code, + const std::string &message) { + if (out_status == nullptr) return; + g_fused_attn_backend_status_buffer = message; + out_status->code = static_cast(code); + out_status->message = g_fused_attn_backend_status_buffer.c_str(); +} + +void set_status(NVTEFusedAttnBackendStatus *out_status, const cudnn_frontend::error_t &err) { + set_status(out_status, err.code, err.err_msg); +} + +void set_ok(NVTEFusedAttnBackendStatus *out_status) { + set_status(out_status, cudnn_frontend::error_code_t::OK, ""); +} + +} // namespace + // select a backend for fused attention +// +// Routing flow: +// 1. Apply TE post-filters that encode policies cuDNN-FE doesn't model directly: +// a. requires_64bit_ragged_offset -> cudnn >= 9.5 +// b. qkv_format == THD requires a padding-style mask +// c. cuDNN <= 9.15 + is_training + bshd/sbhd + max_seqlen_kv % 128 != 0 + +// cuda_graph + non-padding mask is rejected (known capture quirk) +// 2. Dispatch by dtype to the appropriate probe(s): +// - FP8 (E4M3/E5M2): is_supported_fp8_fwd (+ is_supported_fp8_bwd if training) +// - FP16/BF16: is_supported_f16_fwd (+ is_supported_f16_bwd if training) +// The probes call the same _impl that the executor uses, with workspace=nullptr. +// They run validate -> build_operation_graph -> create_execution_plans -> +// check_support -> build_plans, and populate a thread-local cache that the +// executor cache-hits on. +// 3. Return the selected backend, or NVTE_No_Backend if any probe rejects. +// +// When `out_status` is non-null, it is filled with a code + message describing the +// rejection (or {OK, ""} on success). TE post-filter rejections synthesize an +// INVALID_VALUE entry; probe rejections forward the cuDNN-FE / NVTE_CHECK error verbatim. 
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, + NVTEFusedAttnBackendStatus *out_status) { using namespace transformer_engine; - NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - const int device_id = cuda::current_device(); - const int sm_arch_ = cuda::sm_arch(device_id); + // Initialize to OK so callers get a clean status on the success path without us having to + // remember to set it at every return. + set_ok(out_status); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - auto cudnn_runtime_version = cudnnGetVersion(); - // For ragged offsets we only support 32-bit prior to cuDNN 9.5 - // Only used when THD format is requested. 
+ const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const auto cudnn_runtime_version = cudnnGetVersion(); + + // ---------- TE post-filters (apply before delegating to cuDNN-FE) ---------- + + // (1) Ragged-offset width: cuDNN < 9.5 only supports 32-bit offsets. const bool requires_64bit_ragged_offset = (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); - const bool supported_ragged_offset_size = - (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - - if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && - sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // 8.9: t3hd, max_s=512, d=64, padding - ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - (cudnn_runtime_version >= 90700 && - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // sm90: fwd d<=256, bwd d=128 only - // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - 
head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.21: d_qk=192, d_v=128 - (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && - head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && - // pre-9.21: {bshd, sbhd}, {vanilla} - // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} - ((cudnn_runtime_version < 92100 && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || - (cudnn_runtime_version >= 92100 && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && - !requires_64bit_ragged_offset && - // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000) && !return_max_logit) { - if (cudnn_runtime_version >= 8900) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." 
- << std::endl; - } - } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { - bool flag_m512 = false; - bool flag_arb = false; - if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && - (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && - (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && - ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - max_seqlen_q == max_seqlen_kv) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && - ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && - ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) { - flag_m512 = true; + if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "Configuration requires 64-bit ragged offsets, which require cuDNN >= 9.5."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + // (2) THD requires a padding-style mask. 
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "THD-format attention requires a padding-style mask " + "(PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT)."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + // (3) cuDNN-Graph capture quirk on cuDNN <= 9.15: training + bshd/sbhd with + // max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask hangs/miscompiles. + if (cudnn_runtime_version <= 91500 && is_training && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + (max_seqlen_kv % 128 != 0) && cuda_graph && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "Known cuDNN <= 9.15 capture quirk: training + bshd/sbhd + " + "max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask is unsupported."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + // ---------- Dispatch by dtype ---------- + + // Probes use a single-batch graph; capability checks in cuDNN-FE are batch-agnostic. + constexpr size_t probe_batch = 1; + // bottom_right_diagonal is a runtime API knob the router doesn't see; the BRCM-via-mask + // case is captured by attn_mask_type, so we probe with the default top-left alignment. 
+ constexpr bool probe_bottom_right_diagonal = false; + + const bool is_fp8 = + (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2); + const bool is_f16_or_bf16 = + (q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16); + + if (is_fp8) { + // TE-only FP8 post-filters: no 64-bit ragged offsets, no max-logit output. + if (requires_64bit_ragged_offset) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention does not support 64-bit ragged offsets."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if ( - // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - // architecture - ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) || - (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) || - (cudnn_runtime_version >= 90700 && sm_arch_ >= 100)) && - // sequence length - ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || - (cudnn_runtime_version >= 90000)) && - // number of heads - ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 8907)) && - // head dimension - // multiples of 8 - (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 && - // <= 128 - ((head_dim_qk <= 128 && head_dim_v <= 128) || - // 9.1: <= 256 + Hopper + fprop - // 9.5: <= 256 + Hopper + bprop - (head_dim_qk <= 256 && head_dim_v <= 256 && - ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || - (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || - // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 - (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && - layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10.2: any head_dim + any arch + fprop + paged - // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10.2: any head_dim + any 
arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91002 && - (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || - (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || - // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged - (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && - cudnn_runtime_version >= 91100)) && - // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed - (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && - head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && - head_dim_qk != head_dim_v))) && - // bias type - ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (cudnn_runtime_version >= 8906 && - (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - (bias_type == NVTE_Bias_Type::NVTE_ALIBI && - attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - sm_arch_ >= 90) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - (cudnn_runtime_version >= 90000 && - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && - // mask type - // pre-8.9.6: causal - ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} - (cudnn_runtime_version >= 8906 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == 
NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.1: adds thd + {padding, padding_causal} - (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90300 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right} - (cudnn_runtime_version >= 90500 && - layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90600 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} - // for any q_format/kv_format, and paged/non-paged - (cudnn_runtime_version >= 90700 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - 
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv)))) && - // bias + mask combination - (!(cudnn_runtime_version >= 8906 && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - // qkv format - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD || - (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && - ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || - q_format == NVTE_QKV_Format::NVTE_BHSD || - (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || - kv_format == NVTE_QKV_Format::NVTE_BHSD || - (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && - cudnn_runtime_version >= 90700)) && - // sliding window - // pre-9.2: full attn, causal - ((cudnn_runtime_version < 90200 && window_size_left == -1 && - (window_size_right == -1 || window_size_right == 0)) || - // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} - (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && window_size_right == -1 && - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && - (attn_mask_type == 
NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q == max_seqlen_kv)) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || - // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} - (cudnn_runtime_version >= 90600 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && - (window_size_right >= 0 || window_size_right == -1) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - // TODO(cyang): fix bug for BRCM + cross-attention on sm100 - (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700)))) || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700))))) && - max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - dropout == 0.0)))) && - // check 64-bit ragged offset support - (supported_ragged_offset_size) && - // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) && - // softmax type - // pre-9.13.1: vanilla - // 9.13.1+: vanilla, off-by-one, learnable - (cudnn_runtime_version >= 91301 || - (cudnn_runtime_version < 91301 && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && - // determinism on Blackwell - // pre-9.18.1: fwd: deterministic; bwd: non-deterministic - // 9.18.1+: fwd: deterministic; bwd: 
non-deterministic/deterministic - (sm_arch_ < 100 || - (sm_arch_ >= 100 && (!is_training || - (is_training && !deterministic && - (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || - (is_training && deterministic && cudnn_runtime_version >= 91801 && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { - flag_arb = true; + if (return_max_logit) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention does not support return_max_logit."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + const DType q_t = static_cast(q_dtype); + const DType o_t = static_cast(o_dtype); + auto fwd_status = is_supported_fp8_fwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, o_t, scaling_mode, + handle); + if (fwd_status.is_bad()) { + set_status(out_status, fwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - if (flag_arb == true) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } else if ((flag_arb == false) && (flag_m512 == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; - } - int env_backend = static_cast(backend); - env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); - if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && - flag_m512) || - ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && - flag_arb)) { - backend = static_cast(env_backend); + if (is_training) { + auto bwd_status = is_supported_fp8_bwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, 
max_seqlen_kv, head_dim_qk, + head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, + o_t, scaling_mode, handle); + if (bwd_status.is_bad()) { + set_status(out_status, bwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } - if (cudnn_runtime_version < 8901 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - if (cudnn_runtime_version < 8900 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-causal) and " - "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. " - " Please upgrade your cuDNN version if possible." 
- << std::endl; - } - if ((cudnn_runtime_version <= 91500) && is_training && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - (max_seqlen_kv % 128 != 0) && cuda_graph && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-padding)," - " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" - " backward fused attention with graph capture requires cuDNN 9.15.1+. " - "Please upgrade your cuDNN version if possible." - << std::endl; + return NVTE_Fused_Attn_Backend::NVTE_FP8; + } + + if (is_f16_or_bf16) { + const DType q_t = static_cast(q_dtype); + auto fwd_status = is_supported_f16_fwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, + handle); + if (fwd_status.is_bad()) { + set_status(out_status, fwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && sm_arch_ == 120) { - if (cudnn_runtime_version < 91801) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of sm_arch_ == 120 and cudnn_runtime_version < " - "91801 is not supported. " - << " Please upgrade your cuDNN version if possible." << std::endl; - } else if (deterministic && is_training) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Deterministic fused attention on SM120 is not supported." 
- << std::endl; - } else { - // Known missing support for T3HD/TH3D layouts on SM120 - const bool is_t3hd_or_th3d = - (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D); - if (is_t3hd_or_th3d) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM120 is not supported. " - << " Please consider using other THD layouts if possible." << std::endl; - } + if (is_training) { + auto bwd_status = is_supported_f16_bwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, + handle); + if (bwd_status.is_bad()) { + set_status(out_status, bwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - return backend; + + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "Unsupported Q dtype for fused attention " + "(only FP16/BF16/FP8_E4M3/FP8_E5M2 are routable)."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } // NVTE fused attention FWD with separate Q, K and V @@ -661,11 +493,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + const NVTEDType O_type = static_cast(output_O->data.dtype); + const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, 
window_size_right, - return_max_logit, cuda_graph, false); + is_training, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, + window_size_right, return_max_logit, cuda_graph, /*deterministic=*/false, handle, + /*out_status=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, @@ -747,11 +582,14 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + const NVTEDType O_type = static_cast(input_O->data.dtype); + const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph, deterministic); + /*is_training=*/true, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, + attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + window_size_left, window_size_right, /*return_max_logit=*/false, cuda_graph, deterministic, + handle, /*out_status=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6df7ad35c8..57ca14a3e1 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ 
b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1333,4 +1333,140 @@ void fused_attn_arbitrary_seqlen_bwd( NVTE_ERROR("Unexpected workspace_size."); } } + +namespace { +// Probe-time defaults for runtime-only quantities the router doesn't see (paged-KV dims, +// ragged max-tokens, bias dims). These produce a graph whose support surface matches the +// real executor's: for non-paged / non-ragged paths these are unused inside the impl; +// for ragged-THD we rebind to worst-case bounds; for paged we use 1 page of full s_kv per +// batch (= same dims as non-paged), so cuDNN-FE applies the paged-attention support rules. +struct ProbeDims { + int64_t max_b; + int64_t max_t_q; + int64_t max_t_kv; + int64_t num_pages_k; + int64_t num_pages_v; + int64_t page_size_k; + int64_t page_size_v; + int64_t max_pages_per_seq_k; + int64_t max_pages_per_seq_v; + int64_t bias_b; + int64_t bias_h; + int64_t bias_sq; + int64_t bias_skv; +}; + +ProbeDims compute_probe_dims(int64_t batch, int64_t num_attn_heads, int64_t max_seqlen_q, + int64_t max_seqlen_kv, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type) { + const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); + const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + ProbeDims d{}; + d.max_b = (is_ragged_q || is_ragged_kv) ? batch : 0; + d.max_t_q = is_ragged_q ? batch * max_seqlen_q : 0; + d.max_t_kv = is_ragged_kv ? batch * max_seqlen_kv : 0; + d.num_pages_k = is_paged_kv ? batch : 0; + d.num_pages_v = is_paged_kv ? batch : 0; + d.page_size_k = is_paged_kv ? 
max_seqlen_kv : 0; + d.page_size_v = is_paged_kv ? max_seqlen_kv : 0; + d.max_pages_per_seq_k = is_paged_kv ? 1 : 0; + d.max_pages_per_seq_v = is_paged_kv ? 1 : 0; + d.bias_b = has_bias ? batch : 0; + d.bias_h = has_bias ? num_attn_heads : 0; + d.bias_sq = has_bias ? max_seqlen_q : 0; + d.bias_skv = has_bias ? max_seqlen_kv : 0; + return d; +} +} // namespace + +cudnn_frontend::error_t is_supported_f16_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, + bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle) { + const ProbeDims d = compute_probe_dims(static_cast(batch), + static_cast(num_attn_heads), + static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), qkv_layout, + bias_type); + const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); + + size_t workspace_size = 0; + try { + fused_attn::fused_attn_arbitrary_seqlen_fwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.num_pages_k, + d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, + d.max_pages_per_seq_v, d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, + return_max_logit, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, + /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr, + /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, 
/*devPtrDropoutOffset=*/nullptr, + /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, + /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_f16_fwd: unknown failure"}; + } +} + +cudnn_frontend::error_t is_supported_f16_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle) { + const ProbeDims d = compute_probe_dims(static_cast(batch), + static_cast(num_attn_heads), + static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), qkv_layout, + bias_type); + const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); + const NVTE_QKV_Format do_format = o_format; + const NVTE_QKV_Layout dqkv_layout = qkv_layout; + + size_t workspace_size = 0; + try { + fused_attn::fused_attn_arbitrary_seqlen_bwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.bias_b, d.bias_h, + d.bias_sq, d.bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, + dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, 
/*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, + /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr, + /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr, + /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, /*devPtrdO=*/nullptr, + /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr, + /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, + /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_f16_bwd: unknown failure"}; + } +} + } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 8f79b5bb4a..38cf48c1f0 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -12,6 +12,7 @@ #define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ #include +#include #include "common/common.h" #include "transformer_engine/fused_attn.h" @@ -47,6 +48,27 @@ void fused_attn_arbitrary_seqlen_bwd( const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); +// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> +// check_support -> build_plans) for an F16/BF16 forward graph with the given configuration. +// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. 
On OK, +// the built graph is inserted into the same thread-local cache used by +// fused_attn_arbitrary_seqlen_fwd_impl, so the executor cache-hits on matching descriptors. +// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. +cudnn_frontend::error_t is_supported_f16_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, + bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle); + +// Probe: same as above for the F16/BF16 backward graph. +cudnn_frontend::error_t is_supported_f16_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle); + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d97f388459..c9f7a9ee76 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2991,4 +2991,105 @@ void fused_attn_fp8_bwd( return; } } + +cudnn_frontend::error_t is_supported_fp8_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, + NVTE_QKV_Layout qkv_layout, 
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { + // FP8 fwd impl rejects any qkv_format other than BSHD/SBHD/BHSD with NVTE_ERROR; mirror that + // here so the probe returns a typed rejection instead of catching the throw. + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && + qkv_format != NVTE_QKV_Format::NVTE_BHSD) { + return {cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + } + size_t workspace_size = 0; + try { + fused_attn::fused_attn_fp8_fwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout, + qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, + /*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, + /*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, + /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr, + /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, + /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, + /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), + get_cudnn_fe_dtype(o_dtype), scaling_mode, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const 
std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_fp8_fwd: unknown failure"}; + } +} + +cudnn_frontend::error_t is_supported_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle) { + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && + qkv_format != NVTE_QKV_Format::NVTE_BHSD) { + return {cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + } + // For FP8 bwd, dO data type matches O data type and dQKV data type matches Q data type + // (this mirrors the assumption used by callers of fused_attn_fp8_bwd in TE). 
+ const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(q_dtype); + const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); + const cudnn_frontend::DataType_t do_t = o_t; + const cudnn_frontend::DataType_t dqkv_t = qkv_t; + size_t workspace_size = 0; + try { + fused_attn::fused_attn_fp8_bwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), /*scaling_factor=*/1.0f, p_dropout, qkv_layout, + /*o_format=*/qkv_format, /*do_format=*/qkv_format, /*dqkv_layout=*/qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrM=*/nullptr, + /*devPtrO=*/nullptr, /*devPtrdO=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, + /*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, + /*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDescaleQ=*/nullptr, + /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, /*devPtrDescaleO=*/nullptr, + /*devPtrDescaledO=*/nullptr, /*devPtrDescaleS=*/nullptr, /*devPtrDescaledP=*/nullptr, + /*devPtrScaleS=*/nullptr, /*devPtrScaledP=*/nullptr, /*devPtrScaledQ=*/nullptr, + /*devPtrScaledK=*/nullptr, /*devPtrScaledV=*/nullptr, /*devPtrAmaxdP=*/nullptr, + /*devPtrAmaxdQ=*/nullptr, /*devPtrAmaxdK=*/nullptr, /*devPtrAmaxdV=*/nullptr, + /*devPtrQ_t=*/nullptr, /*devPtrK_t=*/nullptr, /*devPtrdO_f16=*/nullptr, + /*devPtrdO_t=*/nullptr, /*devPtrDescaleQ_t=*/nullptr, /*devPtrDescaleK_t=*/nullptr, + /*devPtrDescaledO_t=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, + /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, + /*devPtrDropoutOffset=*/nullptr, qkv_t, o_t, do_t, dqkv_t, scaling_mode, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*workspace=*/nullptr, 
&workspace_size, + /*stream=*/static_cast(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_fp8_bwd: unknown failure"}; + } +} + } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index aaf5039eeb..5c7f11d80e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -8,6 +8,8 @@ * \brief Functions for fused attention for FP8 with seqlen <= 512 */ +#include + #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -39,4 +41,26 @@ void fused_attn_fp8_bwd( const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + +// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> +// check_support -> build_plans) for an FP8 forward graph with the given configuration. +// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK, +// the built graph is inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl. +// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. 
+cudnn_frontend::error_t is_supported_fp8_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle); + +// Probe: same as above for the FP8 backward graph. +cudnn_frontend::error_t is_supported_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 912dc32d35..787e97d628 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ +#include + #include "stdint.h" #include "transformer_engine.h" @@ -196,11 +198,40 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \struct NVTEFusedAttnBackendStatus + * \brief Diagnostic info from \c nvte_get_fused_attn_backend. 
+ * + * Filled by \c nvte_get_fused_attn_backend when the caller passes a non-NULL pointer. + * When the routing decision is supported, \c code is 0 and \c message is the empty + * string. When the routing rejects the configuration, \c code is the underlying + * cuDNN-FE \c cudnn_frontend::error_code_t cast to \c int (TE-synthesized post-filter + * rejections use \c INVALID_VALUE), and \c message is a null-terminated human-readable + * reason that points into per-thread storage owned by TE. The pointer is valid only + * until the next call to \c nvte_get_fused_attn_backend on the same thread. + */ +typedef struct NVTEFusedAttnBackendStatus { + int code; + const char *message; +} NVTEFusedAttnBackendStatus; + /*! \brief Get fused attention backend based on input parameters. + * + * Authoritative routing: when a non-NVTE_No_Backend value is returned, the configuration + * is guaranteed to compile through cuDNN-FE (validate -> build_operation_graph -> + * create_execution_plans -> check_support -> build_plans). The router applies a small + * set of TE-specific post-filters in addition to delegating to cuDNN-FE for capability + * checks. On success the built plan is cached, so the executor avoids rebuilding. * * \param[in] is_training Whether the model is in training mode. * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] o_dtype The data type of output Tensor O. Used by the FP8 + * branch to disambiguate FP8 vs HALF/BF16 output; + * ignored by the F16/BF16 branch (pass q_dtype). + * \param[in] scaling_mode Scaling mode of the input tensors. Used by the FP8 + * branch to select among delayed/current/MXFP8 recipes; + * ignored by the F16/BF16 branch + * (pass NVTE_DELAYED_TENSOR_SCALING). * \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] bias_type The attention bias type. * \param[in] attn_mask_type The attention mask type. 
@@ -217,13 +248,22 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] return_max_logit Whether to produce Max along with Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] deterministic Whether determinism is required or not. + * \param[in] handle cuDNN handle used for the support chain. Required. + * \param[out] out_status Optional. When non-NULL, populated with a code + + * message describing why the configuration was + * rejected (NVTE_No_Backend) or with code=0 and + * message="" on success. The message buffer lives in + * thread-local storage and is overwritten on every + * call on the same thread. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, + NVTEFusedAttnBackendStatus *out_status); /*! \brief Compute dot product attention with separate Q, K and V. 
* diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 76f2d92891..c6a8897089 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include "../extensions.h" +#include "common/cudnn_utils.h" #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -17,11 +18,14 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool deterministic) { + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, static_cast(q_dtype), static_cast(kv_dtype), + static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, + mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, handle, + /*out_status=*/nullptr); return backend; } @@ -272,11 +276,13 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + auto _handle_fwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, 
mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, static_cast(dtype), static_cast(dtype), + static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, + softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, + /*cuda_graph=*/false, deterministic, _handle_fwd, /*out_status=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -548,11 +554,13 @@ static void FusedAttnBackwardImpl( /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); + auto _handle_bwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, static_cast(dtype), static_cast(dtype), + static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, + softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, + /*cuda_graph=*/false, deterministic, _handle_bwd, /*out_status=*/nullptr); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git 
a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index e6781bd58a..d67cd4a6b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -6,6 +6,7 @@ #include "../extensions.h" #include "common.h" +#include "common/cudnn_utils.h" #include "pybind.h" namespace { @@ -40,17 +41,25 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s namespace transformer_engine::pytorch { // get the fused attention backend +// +// NOTE: the underlying nvte_get_fused_attn_backend now takes o_dtype and scaling_mode in +// addition to q_dtype/kv_dtype. For the F16/BF16 routing path those are ignored, so we pass +// q_dtype as o_dtype and DELAYED_TENSOR_SCALING. This Python-facing wrapper therefore keeps +// its existing signature; FP8 callers that want authoritative routing for non-default scaling +// recipes should add o_dtype / scaling_mode parameters in a follow-up. 
NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, deterministic); + is_training, static_cast(q_dtype), static_cast(kv_dtype), + static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, + attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, + max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, + return_max_logit, cuda_graph, deterministic, handle, /*out_status=*/nullptr); return fused_attention_backend; } From 16b837cd0f3e27b7638bfb2a90d39056680a6b6e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 01:58:25 +0000 Subject: [PATCH 02/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 4 +-- .../fused_attn_f16_arbitrary_seqlen.cu | 34 +++++++++---------- .../common/fused_attn/fused_attn_fp8.cu | 12 +++---- .../pytorch/csrc/extensions/attention.cpp | 4 +-- 4 files changed, 26 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp 
b/transformer_engine/common/fused_attn/fused_attn.cpp index 615f7c2a03..95405c0d6f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -370,8 +370,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( auto bwd_status = is_supported_fp8_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, - o_t, scaling_mode, handle); + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, o_t, + scaling_mode, handle); if (bwd_status.is_bad()) { set_status(out_status, bwd_status); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 57ca14a3e1..1ced84755c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1391,11 +1391,10 @@ cudnn_frontend::error_t is_supported_f16_fwd( bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle) { - const ProbeDims d = compute_probe_dims(static_cast(batch), - static_cast(num_attn_heads), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv), qkv_layout, - bias_type); + const ProbeDims d = + compute_probe_dims(static_cast(batch), static_cast(num_attn_heads), + static_cast(max_seqlen_q), static_cast(max_seqlen_kv), + qkv_layout, bias_type); const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); size_t workspace_size = 0; @@ -1405,17 +1404,17 @@ 
cudnn_frontend::error_t is_supported_f16_fwd( static_cast(num_gqa_groups), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), static_cast(head_dim_qk), static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.num_pages_k, - d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, - d.max_pages_per_seq_v, d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, - return_max_logit, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, d.max_pages_per_seq_v, + d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, return_max_logit, + /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr, /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr, - /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, - get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), + /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; } catch (const std::exception &e) { @@ -1432,11 +1431,10 @@ cudnn_frontend::error_t is_supported_f16_bwd( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle) { - const ProbeDims d = 
compute_probe_dims(static_cast(batch), - static_cast(num_attn_heads), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv), qkv_layout, - bias_type); + const ProbeDims d = + compute_probe_dims(static_cast(batch), static_cast(num_attn_heads), + static_cast(max_seqlen_q), static_cast(max_seqlen_kv), + qkv_layout, bias_type); const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); const NVTE_QKV_Format do_format = o_format; const NVTE_QKV_Layout dqkv_layout = qkv_layout; @@ -1457,8 +1455,8 @@ cudnn_frontend::error_t is_supported_f16_bwd( /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, - /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, - get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), + /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; } catch (const std::exception &e) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index c9f7a9ee76..8a152cf489 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -3014,21 +3014,21 @@ cudnn_frontend::error_t is_supported_fp8_fwd( static_cast(num_gqa_groups), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), static_cast(head_dim_qk), static_cast(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout, - qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, + qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, 
/*devPtrV=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, /*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr, /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, - /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), - get_cudnn_fe_dtype(o_dtype), scaling_mode, + /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), get_cudnn_fe_dtype(o_dtype), + scaling_mode, /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; - } catch (const std::exception &e) { + } catch (const std::exception& e) { return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; } catch (...) { return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, @@ -3084,7 +3084,7 @@ cudnn_frontend::error_t is_supported_fp8_bwd( /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; - } catch (const std::exception &e) { + } catch (const std::exception& e) { return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; } catch (...) 
{ return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index d67cd4a6b9..256ede6e55 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -58,8 +58,8 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, deterministic, handle, /*out_status=*/nullptr); + max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_logit, + cuda_graph, deterministic, handle, /*out_status=*/nullptr); return fused_attention_backend; } From 42bcd89036fc8d093a192985304018c4497b22ab Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 14:11:46 -0700 Subject: [PATCH 03/23] replace code+string with string only Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 98 +++++++++---------- .../fused_attn_f16_arbitrary_seqlen.cu | 18 ++-- .../fused_attn_f16_arbitrary_seqlen.h | 19 ++-- .../common/fused_attn/fused_attn_fp8.cu | 24 ++--- .../common/fused_attn/fused_attn_fp8.h | 18 ++-- .../include/transformer_engine/fused_attn.h | 37 +++---- .../jax/csrc/extensions/attention.cpp | 6 +- .../pytorch/csrc/extensions/attention.cpp | 2 +- 8 files changed, 103 insertions(+), 119 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 95405c0d6f..961b503c1c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ 
b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -228,28 +228,17 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { namespace { -// Per-thread storage for the message string handed back through -// NVTEFusedAttnBackendStatus::message. Re-used (cleared + re-populated) on every call to -// nvte_get_fused_attn_backend on this thread, which is exactly the lifetime documented in the -// public header. -thread_local std::string g_fused_attn_backend_status_buffer; - -// Apply (code, msg) to *out_status (if non-null), routing the message through the -// thread-local buffer so the returned `const char*` outlives this function call. -void set_status(NVTEFusedAttnBackendStatus *out_status, cudnn_frontend::error_code_t code, - const std::string &message) { - if (out_status == nullptr) return; - g_fused_attn_backend_status_buffer = message; - out_status->code = static_cast(code); - out_status->message = g_fused_attn_backend_status_buffer.c_str(); -} - -void set_status(NVTEFusedAttnBackendStatus *out_status, const cudnn_frontend::error_t &err) { - set_status(out_status, err.code, err.err_msg); -} - -void set_ok(NVTEFusedAttnBackendStatus *out_status) { - set_status(out_status, cudnn_frontend::error_code_t::OK, ""); +// Per-thread storage for the diagnostic string handed back through *out_reason. Re-used +// (cleared + re-populated) on every call to nvte_get_fused_attn_backend on this thread, +// which is exactly the lifetime documented in the public header. +thread_local std::string g_fused_attn_backend_reason_buffer; + +// Stash `reason` in the thread-local buffer and (if non-null) point *out_reason at it, +// so the returned `const char*` outlives this function call. 
+void set_reason(const char **out_reason, const std::string &reason) { + if (out_reason == nullptr) return; + g_fused_attn_backend_reason_buffer = reason; + *out_reason = g_fused_attn_backend_reason_buffer.c_str(); } } // namespace @@ -271,9 +260,9 @@ void set_ok(NVTEFusedAttnBackendStatus *out_status) { // executor cache-hits on. // 3. Return the selected backend, or NVTE_No_Backend if any probe rejects. // -// When `out_status` is non-null, it is filled with a code + message describing the -// rejection (or {OK, ""} on success). TE post-filter rejections synthesize an -// INVALID_VALUE entry; probe rejections forward the cuDNN-FE / NVTE_CHECK error verbatim. +// When `out_reason` is non-null, it is set to "" on success or to a tagged diagnostic +// string on rejection. TE post-filter rejections are tagged "[INVALID_VALUE] ..."; +// probe rejections forward the probe's tagged string verbatim. NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -281,11 +270,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, - NVTEFusedAttnBackendStatus *out_status) { + const char **out_reason) { using namespace transformer_engine; - // Initialize to OK so callers get a clean status on the success path without us having to + // Initialize to "" so callers get a clean status on the success path without us having to // remember to set it at every return. 
- set_ok(out_status); + set_reason(out_reason, ""); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -300,8 +289,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "Configuration requires 64-bit ragged offsets, which require cuDNN >= 9.5."); + set_reason(out_reason, + "[INVALID_VALUE] Configuration requires 64-bit ragged offsets, which require " + "cuDNN >= 9.5."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -310,8 +300,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "THD-format attention requires a padding-style mask " + set_reason(out_reason, + "[INVALID_VALUE] THD-format attention requires a padding-style mask " "(PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT)."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -324,8 +314,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "Known cuDNN <= 9.15 capture quirk: training + bshd/sbhd + " + set_reason(out_reason, + "[INVALID_VALUE] Known cuDNN <= 9.15 capture quirk: training + bshd/sbhd + " "max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask is unsupported."); return 
NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -346,34 +336,34 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (is_fp8) { // TE-only FP8 post-filters: no 64-bit ragged offsets, no max-logit output. if (requires_64bit_ragged_offset) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention does not support 64-bit ragged offsets."); + set_reason(out_reason, + "[INVALID_VALUE] FP8 fused attention does not support 64-bit ragged offsets."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (return_max_logit) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention does not support return_max_logit."); + set_reason(out_reason, + "[INVALID_VALUE] FP8 fused attention does not support return_max_logit."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } const DType q_t = static_cast(q_dtype); const DType o_t = static_cast(o_dtype); - auto fwd_status = is_supported_fp8_fwd( + std::string fwd_reason = is_supported_fp8_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, o_t, scaling_mode, handle); - if (fwd_status.is_bad()) { - set_status(out_status, fwd_status); + if (!fwd_reason.empty()) { + set_reason(out_reason, fwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { - auto bwd_status = is_supported_fp8_bwd( + std::string bwd_reason = is_supported_fp8_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, o_t, scaling_mode, handle); - if (bwd_status.is_bad()) { - set_status(out_status, bwd_status); + if (!bwd_reason.empty()) { + set_reason(out_reason, bwd_reason); 
return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } @@ -382,31 +372,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (is_f16_or_bf16) { const DType q_t = static_cast(q_dtype); - auto fwd_status = is_supported_f16_fwd( + std::string fwd_reason = is_supported_f16_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, handle); - if (fwd_status.is_bad()) { - set_status(out_status, fwd_status); + if (!fwd_reason.empty()) { + set_reason(out_reason, fwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { - auto bwd_status = is_supported_f16_bwd( + std::string bwd_reason = is_supported_f16_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, handle); - if (bwd_status.is_bad()) { - set_status(out_status, bwd_status); + if (!bwd_reason.empty()) { + set_reason(out_reason, bwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "Unsupported Q dtype for fused attention " + set_reason(out_reason, + "[INVALID_VALUE] Unsupported Q dtype for fused attention " "(only FP16/BF16/FP8_E4M3/FP8_E5M2 are routable)."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -500,7 +490,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso is_training, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, 
/*deterministic=*/false, handle, - /*out_status=*/nullptr); + /*out_reason=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, @@ -589,7 +579,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso /*is_training=*/true, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, /*return_max_logit=*/false, cuda_graph, deterministic, - handle, /*out_status=*/nullptr); + handle, /*out_reason=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1ced84755c..70094c4e93 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1385,7 +1385,7 @@ ProbeDims compute_probe_dims(int64_t batch, int64_t num_attn_heads, int64_t max_ } } // namespace -cudnn_frontend::error_t is_supported_f16_fwd( +std::string is_supported_f16_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -1416,16 +1416,15 @@ cudnn_frontend::error_t is_supported_f16_fwd( /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception &e) { - 
return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_f16_fwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_f16_fwd: unknown failure"; } } -cudnn_frontend::error_t is_supported_f16_bwd( +std::string is_supported_f16_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -1458,12 +1457,11 @@ cudnn_frontend::error_t is_supported_f16_bwd( /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception &e) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) 
{ - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_f16_bwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_f16_bwd: unknown failure"; } } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 38cf48c1f0..0eabe3e8dc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -12,7 +12,8 @@ #define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ #include -#include + +#include #include "common/common.h" #include "transformer_engine/fused_attn.h" @@ -50,11 +51,15 @@ void fused_attn_arbitrary_seqlen_bwd( // Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> // check_support -> build_plans) for an F16/BF16 forward graph with the given configuration. -// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK, -// the built graph is inserted into the same thread-local cache used by -// fused_attn_arbitrary_seqlen_fwd_impl, so the executor cache-hits on matching descriptors. -// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. -cudnn_frontend::error_t is_supported_f16_fwd( +// Returns an empty string iff the graph compiles end-to-end; on OK the built graph is +// inserted into the same thread-local cache used by fused_attn_arbitrary_seqlen_fwd_impl, +// so the executor cache-hits on matching descriptors. +// +// On rejection, returns a non-empty diagnostic of the form +// "[] " +// where is a stable tag mirroring cudnn_frontend::error_code_t names +// (e.g. GRAPH_NOT_SUPPORTED for cuDNN-FE rejections forwarded from the support chain). 
+std::string is_supported_f16_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -62,7 +67,7 @@ cudnn_frontend::error_t is_supported_f16_fwd( int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle); // Probe: same as above for the F16/BF16 backward graph. -cudnn_frontend::error_t is_supported_f16_bwd( +std::string is_supported_f16_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 8a152cf489..27bd0af3f3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2992,7 +2992,7 @@ void fused_attn_fp8_bwd( } } -cudnn_frontend::error_t is_supported_fp8_fwd( +std::string is_supported_fp8_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -3004,8 +3004,7 @@ cudnn_frontend::error_t is_supported_fp8_fwd( const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return {cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + return "[INVALID_VALUE] FP8 fused attention only supports BSHD/SBHD/BHSD layouts."; } size_t workspace_size = 
0; try { @@ -3027,16 +3026,15 @@ cudnn_frontend::error_t is_supported_fp8_fwd( /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception& e) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_fp8_fwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_fp8_fwd: unknown failure"; } } -cudnn_frontend::error_t is_supported_fp8_bwd( +std::string is_supported_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -3046,8 +3044,7 @@ cudnn_frontend::error_t is_supported_fp8_bwd( const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return {cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + return "[INVALID_VALUE] FP8 fused attention only supports BSHD/SBHD/BHSD layouts."; } // For FP8 bwd, dO data type matches O data type and dQKV data type matches Q data type // (this mirrors the assumption used by callers of fused_attn_fp8_bwd in TE). 
@@ -3083,12 +3080,11 @@ cudnn_frontend::error_t is_supported_fp8_bwd( /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception& e) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_fp8_bwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_fp8_bwd: unknown failure"; } } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 5c7f11d80e..f91cdcf291 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -8,7 +8,7 @@ * \brief Functions for fused attention for FP8 with seqlen <= 512 */ -#include +#include #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -44,10 +44,15 @@ void fused_attn_fp8_bwd( // Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> // check_support -> build_plans) for an FP8 forward graph with the given configuration. -// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK, -// the built graph is inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl. -// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. -cudnn_frontend::error_t is_supported_fp8_fwd( +// Returns an empty string iff the graph compiles end-to-end; on OK the built graph is +// inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl. 
+// +// On rejection, returns a non-empty diagnostic of the form +// "[] " +// where mirrors cudnn_frontend::error_code_t names (INVALID_VALUE for the +// FP8-only layout pre-filter, GRAPH_NOT_SUPPORTED for cuDNN-FE rejections forwarded +// from the support chain). +std::string is_supported_fp8_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -56,12 +61,11 @@ cudnn_frontend::error_t is_supported_fp8_fwd( cudnnHandle_t handle); // Probe: same as above for the FP8 backward graph. -cudnn_frontend::error_t is_supported_fp8_bwd( +std::string is_supported_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); ->>>>>>> c9006435 (refactor nvte_get_fused_attn_backend with FE calls) } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 787e97d628..bbcdf08995 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -198,22 +198,6 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); -/*! \struct NVTEFusedAttnBackendStatus - * \brief Diagnostic info from \c nvte_get_fused_attn_backend. 
- * - * Filled by \c nvte_get_fused_attn_backend when the caller passes a non-NULL pointer. - * When the routing decision is supported, \c code is 0 and \c message is the empty - * string. When the routing rejects the configuration, \c code is the underlying - * cuDNN-FE \c cudnn_frontend::error_code_t cast to \c int (TE-synthesized post-filter - * rejections use \c INVALID_VALUE), and \c message is a null-terminated human-readable - * reason that points into per-thread storage owned by TE. The pointer is valid only - * until the next call to \c nvte_get_fused_attn_backend on the same thread. - */ -typedef struct NVTEFusedAttnBackendStatus { - int code; - const char *message; -} NVTEFusedAttnBackendStatus; - /*! \brief Get fused attention backend based on input parameters. * * Authoritative routing: when a non-NVTE_No_Backend value is returned, the configuration @@ -249,12 +233,19 @@ typedef struct NVTEFusedAttnBackendStatus { * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] deterministic Whether determinism is required or not. * \param[in] handle cuDNN handle used for the support chain. Required. - * \param[out] out_status Optional. When non-NULL, populated with a code + - * message describing why the configuration was - * rejected (NVTE_No_Backend) or with code=0 and - * message="" on success. The message buffer lives in - * thread-local storage and is overwritten on every - * call on the same thread. + * \param[out] out_reason Optional. When non-NULL, set to a null-terminated + * diagnostic string describing why the configuration + * was rejected (NVTE_No_Backend) or set to "" on + * success. Rejection messages are tagged with a + * stable category prefix that mirrors + * \c cudnn_frontend::error_code_t, e.g. + * \c "[INVALID_VALUE] ..." for TE post-filter + * rejections and FP8 layout pre-filter rejections, + * \c "[GRAPH_NOT_SUPPORTED] ..." for cuDNN-FE + * rejections forwarded from the support chain. 
The + * pointer points into per-thread storage owned by TE + * and is valid only until the next call to + * \c nvte_get_fused_attn_backend on the same thread. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, @@ -263,7 +254,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, - NVTEFusedAttnBackendStatus *out_status); + const char **out_reason); /*! \brief Compute dot product attention with separate Q, K and V. * diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index c6a8897089..669570daa5 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -25,7 +25,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, handle, - /*out_status=*/nullptr); + /*out_reason=*/nullptr); return backend; } @@ -282,7 +282,7 @@ static void FusedAttnForwardImpl( static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_fwd, /*out_status=*/nullptr); + /*cuda_graph=*/false, deterministic, _handle_fwd, /*out_reason=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary 
tensors (to be propagated to the backward pass later) */ @@ -560,7 +560,7 @@ static void FusedAttnBackwardImpl( static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_bwd, /*out_status=*/nullptr); + /*cuda_graph=*/false, deterministic, _handle_bwd, /*out_reason=*/nullptr); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 256ede6e55..3af4ba3831 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -59,7 +59,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_logit, - cuda_graph, deterministic, handle, /*out_status=*/nullptr); + cuda_graph, deterministic, handle, /*out_reason=*/nullptr); return fused_attention_backend; } From de8e81457b8576b222313ae068042ff6f5b68598 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 17:05:13 -0700 Subject: [PATCH 04/23] clean up logic/comments/structure Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 4 +- .../common/fused_attn/fused_attn.cpp | 113 +++++--------- .../fused_attn_f16_arbitrary_seqlen.cu | 140 ++++++++---------- .../fused_attn_f16_arbitrary_seqlen.h 
| 21 +-- .../common/fused_attn/fused_attn_fp8.cu | 26 ++-- .../common/fused_attn/fused_attn_fp8.h | 21 +-- .../include/transformer_engine/fused_attn.h | 34 +---- .../common/util/pybind_helper.h | 7 + .../jax/cpp_extensions/attention.py | 20 ++- transformer_engine/jax/csrc/extensions.h | 15 +- .../jax/csrc/extensions/attention.cpp | 32 ++-- .../attention/dot_product_attention/utils.py | 9 +- transformer_engine/pytorch/csrc/extensions.h | 15 +- .../pytorch/csrc/extensions/attention.cpp | 30 ++-- 14 files changed, 209 insertions(+), 278 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 8b727b1d43..f21c6cb2d0 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -444,7 +444,7 @@ def _check_configs(self): "is either BSHD_BSHD_BSHD or THD_THD_THD" ) - self.backend = FusedAttnHelper( + self.backend, message = FusedAttnHelper( self.is_training, self.dtype, self.dtype, @@ -462,7 +462,7 @@ def _check_configs(self): (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - pytest.skip("Unsupported inputs combination or device compute capability.") + pytest.skip(message) if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 961b503c1c..9587693645 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -228,41 +228,19 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { namespace { -// Per-thread storage for the diagnostic string handed back through *out_reason. Re-used -// (cleared + re-populated) on every call to nvte_get_fused_attn_backend on this thread, -// which is exactly the lifetime documented in the public header. 
-thread_local std::string g_fused_attn_backend_reason_buffer; - -// Stash `reason` in the thread-local buffer and (if non-null) point *out_reason at it, -// so the returned `const char*` outlives this function call. -void set_reason(const char **out_reason, const std::string &reason) { - if (out_reason == nullptr) return; - g_fused_attn_backend_reason_buffer = reason; - *out_reason = g_fused_attn_backend_reason_buffer.c_str(); +// per-thread storage for the diagnostic string +// re-used (cleared + re-populated) on every call to nvte_get_fused_attn_backend on this thread +thread_local std::string fused_attn_backend_message_buffer; + +void set_message(const char **message, const std::string &reason) { + if (message == nullptr) return; + fused_attn_backend_message_buffer = reason; + *message = fused_attn_backend_message_buffer.c_str(); } } // namespace // select a backend for fused attention -// -// Routing flow: -// 1. Apply TE post-filters that encode policies cuDNN-FE doesn't model directly: -// a. requires_64bit_ragged_offset -> cudnn >= 9.5 -// b. qkv_format == THD requires a padding-style mask -// c. cuDNN <= 9.15 + is_training + bshd/sbhd + max_seqlen_kv % 128 != 0 + -// cuda_graph + non-padding mask is rejected (known capture quirk) -// 2. Dispatch by dtype to the appropriate probe(s): -// - FP8 (E4M3/E5M2): is_supported_fp8_fwd (+ is_supported_fp8_bwd if training) -// - FP16/BF16: is_supported_f16_fwd (+ is_supported_f16_bwd if training) -// The probes call the same _impl that the executor uses, with workspace=nullptr. -// They run validate -> build_operation_graph -> create_execution_plans -> -// check_support -> build_plans, and populate a thread-local cache that the -// executor cache-hits on. -// 3. Return the selected backend, or NVTE_No_Backend if any probe rejects. -// -// When `out_reason` is non-null, it is set to "" on success or to a tagged diagnostic -// string on rejection. 
TE post-filter rejections are tagged "[INVALID_VALUE] ..."; -// probe rejections forward the probe's tagged string verbatim. NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -270,62 +248,50 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, - const char **out_reason) { + const char **message) { using namespace transformer_engine; - // Initialize to "" so callers get a clean status on the success path without us having to - // remember to set it at every return. - set_reason(out_reason, ""); + set_message(message, ""); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); const auto cudnn_runtime_version = cudnnGetVersion(); - // ---------- TE post-filters (apply before delegating to cuDNN-FE) ---------- - - // (1) Ragged-offset width: cuDNN < 9.5 only supports 32-bit offsets. 
+ // THD + 64-bit ragged offsets require cuDNN >= 9.5 const bool requires_64bit_ragged_offset = (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { - set_reason(out_reason, - "[INVALID_VALUE] Configuration requires 64-bit ragged offsets, which require " + set_message(message, + "Configuration requires 64-bit ragged offsets, which require " "cuDNN >= 9.5."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - // (2) THD requires a padding-style mask. + // THD requires padding-style mask if (qkv_format == NVTE_QKV_Format::NVTE_THD && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_reason(out_reason, - "[INVALID_VALUE] THD-format attention requires a padding-style mask " - "(PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT)."); + set_message(message, + "THD format requires PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT mask."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - // (3) cuDNN-Graph capture quirk on cuDNN <= 9.15: training + bshd/sbhd with - // max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask hangs/miscompiles. 
+ // avoid CUDA graph issue with cuDNN <= 9.15 if (cudnn_runtime_version <= 91500 && is_training && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && (max_seqlen_kv % 128 != 0) && cuda_graph && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_reason(out_reason, - "[INVALID_VALUE] Known cuDNN <= 9.15 capture quirk: training + bshd/sbhd + " - "max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask is unsupported."); + set_message(message, + "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - // ---------- Dispatch by dtype ---------- - - // Probes use a single-batch graph; capability checks in cuDNN-FE are batch-agnostic. constexpr size_t probe_batch = 1; - // bottom_right_diagonal is a runtime API knob the router doesn't see; the BRCM-via-mask - // case is captured by attn_mask_type, so we probe with the default top-left alignment. constexpr bool probe_bottom_right_diagonal = false; const bool is_fp8 = @@ -334,36 +300,30 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16); if (is_fp8) { - // TE-only FP8 post-filters: no 64-bit ragged offsets, no max-logit output. 
- if (requires_64bit_ragged_offset) { - set_reason(out_reason, - "[INVALID_VALUE] FP8 fused attention does not support 64-bit ragged offsets."); - return NVTE_Fused_Attn_Backend::NVTE_No_Backend; - } if (return_max_logit) { - set_reason(out_reason, - "[INVALID_VALUE] FP8 fused attention does not support return_max_logit."); + set_message(message, + "FP8 fused attention does not support return_max_logit=True."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - const DType q_t = static_cast(q_dtype); + const DType qkv_t = static_cast(q_dtype); const DType o_t = static_cast(o_dtype); std::string fwd_reason = is_supported_fp8_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, o_t, scaling_mode, + window_size_left, window_size_right, probe_bottom_right_diagonal, qkv_t, o_t, scaling_mode, handle); if (!fwd_reason.empty()) { - set_reason(out_reason, fwd_reason); + set_message(message, fwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { std::string bwd_reason = is_supported_fp8_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, o_t, - scaling_mode, handle); + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, qkv_t, + o_t, scaling_mode, handle); if (!bwd_reason.empty()) { - set_reason(out_reason, bwd_reason); + set_message(message, bwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } @@ -371,33 +331,32 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if (is_f16_or_bf16) { - const DType q_t = static_cast(q_dtype); + const DType qkv_t = static_cast(q_dtype); std::string fwd_reason = 
is_supported_f16_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, + softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, qkv_t, handle); if (!fwd_reason.empty()) { - set_reason(out_reason, fwd_reason); + set_message(message, fwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { std::string bwd_reason = is_supported_f16_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, qkv_t, handle); if (!bwd_reason.empty()) { - set_reason(out_reason, bwd_reason); + set_message(message, bwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - set_reason(out_reason, - "[INVALID_VALUE] Unsupported Q dtype for fused attention " - "(only FP16/BF16/FP8_E4M3/FP8_E5M2 are routable)."); + set_message(message, + "Unsupported QKV dtype qkv_dtype=" + std::to_string(q_dtype) + " ."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -490,7 +449,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso is_training, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, /*deterministic=*/false, handle, - /*out_reason=*/nullptr); + /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { fused_attn_max_512_fwd(b, h_q, max_seqlen_q, 
max_seqlen_kv, d_qk, is_training, attn_scale, @@ -579,7 +538,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso /*is_training=*/true, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, /*return_max_logit=*/false, cuda_graph, deterministic, - handle, /*out_reason=*/nullptr); + handle, /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 70094c4e93..81e66d8800 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1334,31 +1334,18 @@ void fused_attn_arbitrary_seqlen_bwd( } } -namespace { -// Probe-time defaults for runtime-only quantities the router doesn't see (paged-KV dims, -// ragged max-tokens, bias dims). These produce a graph whose support surface matches the -// real executor's: for non-paged / non-ragged paths these are unused inside the impl; -// for ragged-THD we rebind to worst-case bounds; for paged we use 1 page of full s_kv per -// batch (= same dims as non-paged), so cuDNN-FE applies the paged-attention support rules. 
-struct ProbeDims { - int64_t max_b; - int64_t max_t_q; - int64_t max_t_kv; - int64_t num_pages_k; - int64_t num_pages_v; - int64_t page_size_k; - int64_t page_size_v; - int64_t max_pages_per_seq_k; - int64_t max_pages_per_seq_v; - int64_t bias_b; - int64_t bias_h; - int64_t bias_sq; - int64_t bias_skv; -}; - -ProbeDims compute_probe_dims(int64_t batch, int64_t num_attn_heads, int64_t max_seqlen_q, - int64_t max_seqlen_kv, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type) { +std::string is_supported_f16_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, + bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, DType qkv_dtype, + cudnnHandle_t handle) { + const auto b = static_cast(batch); + const auto h = static_cast(num_attn_heads); + const auto sq = static_cast(max_seqlen_q); + const auto skv = static_cast(max_seqlen_kv); + const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); @@ -1367,45 +1354,29 @@ ProbeDims compute_probe_dims(int64_t batch, int64_t num_attn_heads, int64_t max_ const bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - ProbeDims d{}; - d.max_b = (is_ragged_q || is_ragged_kv) ? batch : 0; - d.max_t_q = is_ragged_q ? batch * max_seqlen_q : 0; - d.max_t_kv = is_ragged_kv ? batch * max_seqlen_kv : 0; - d.num_pages_k = is_paged_kv ? batch : 0; - d.num_pages_v = is_paged_kv ? batch : 0; - d.page_size_k = is_paged_kv ? max_seqlen_kv : 0; - d.page_size_v = is_paged_kv ? 
max_seqlen_kv : 0; - d.max_pages_per_seq_k = is_paged_kv ? 1 : 0; - d.max_pages_per_seq_v = is_paged_kv ? 1 : 0; - d.bias_b = has_bias ? batch : 0; - d.bias_h = has_bias ? num_attn_heads : 0; - d.bias_sq = has_bias ? max_seqlen_q : 0; - d.bias_skv = has_bias ? max_seqlen_kv : 0; - return d; -} -} // namespace - -std::string is_supported_f16_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, - bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle) { - const ProbeDims d = - compute_probe_dims(static_cast(batch), static_cast(num_attn_heads), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv), - qkv_layout, bias_type); - const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); + const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0; + const int64_t max_t_q = is_ragged_q ? b * sq : 0; + const int64_t max_t_kv = is_ragged_kv ? b * skv : 0; + const int64_t num_pages_k = is_paged_kv ? b : 0; + const int64_t num_pages_v = is_paged_kv ? b : 0; + const int64_t page_size_k = is_paged_kv ? skv : 0; + const int64_t page_size_v = is_paged_kv ? skv : 0; + const int64_t max_pages_per_seq_k = is_paged_kv ? 1 : 0; + const int64_t max_pages_per_seq_v = is_paged_kv ? 1 : 0; + const int64_t bias_b = has_bias ? b : 0; + const int64_t bias_h = has_bias ? h : 0; + const int64_t bias_sq = has_bias ? sq : 0; + const int64_t bias_skv = has_bias ? 
skv : 0; + + const NVTE_QKV_Format o_format = q_format; size_t workspace_size = 0; try { fused_attn::fused_attn_arbitrary_seqlen_fwd_impl( - static_cast(batch), static_cast(num_attn_heads), - static_cast(num_gqa_groups), static_cast(max_seqlen_q), - static_cast(max_seqlen_kv), static_cast(head_dim_qk), - static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.num_pages_k, - d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, d.max_pages_per_seq_v, - d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, return_max_logit, + b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), + static_cast(head_dim_v), max_b, max_t_q, max_t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + bias_sq, bias_skv, is_training, return_max_logit, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, @@ -1413,14 +1384,15 @@ std::string is_supported_f16_fwd( /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr, - /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(qkv_dtype), /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return ""; } catch (const std::exception &e) { - return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); + return e.what(); } catch (...) 
{ - return "[GRAPH_NOT_SUPPORTED] is_supported_f16_fwd: unknown failure"; + return "is_supported_f16_fwd: unknown failure."; } } @@ -1429,23 +1401,36 @@ std::string is_supported_f16_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle) { - const ProbeDims d = - compute_probe_dims(static_cast(batch), static_cast(num_attn_heads), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv), - qkv_layout, bias_type); - const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, cudnnHandle_t handle) { + const auto b = static_cast(batch); + const auto h = static_cast(num_attn_heads); + const auto sq = static_cast(max_seqlen_q); + const auto skv = static_cast(max_seqlen_kv); + + const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); + const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + const int64_t max_b = (is_ragged_q || is_ragged_kv) ? b : 0; + const int64_t max_t_q = is_ragged_q ? b * sq : 0; + const int64_t max_t_kv = is_ragged_kv ? b * skv : 0; + const int64_t bias_b = has_bias ? b : 0; + const int64_t bias_h = has_bias ? h : 0; + const int64_t bias_sq = has_bias ? sq : 0; + const int64_t bias_skv = has_bias ? 
skv : 0; + + const NVTE_QKV_Format o_format = q_format; const NVTE_QKV_Format do_format = o_format; const NVTE_QKV_Layout dqkv_layout = qkv_layout; size_t workspace_size = 0; try { fused_attn::fused_attn_arbitrary_seqlen_bwd_impl( - static_cast(batch), static_cast(num_attn_heads), - static_cast(num_gqa_groups), static_cast(max_seqlen_q), - static_cast(max_seqlen_kv), static_cast(head_dim_qk), - static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.bias_b, d.bias_h, - d.bias_sq, d.bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, + b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), + static_cast(head_dim_v), max_b, max_t_q, max_t_kv, bias_b, bias_h, bias_sq, + bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr, @@ -1454,14 +1439,15 @@ std::string is_supported_f16_bwd( /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, - /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(qkv_dtype), /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return ""; } catch (const std::exception &e) { - return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); + return e.what(); } catch (...) 
{ - return "[GRAPH_NOT_SUPPORTED] is_supported_f16_bwd: unknown failure"; + return "is_supported_f16_bwd: unknown failure."; } } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 0eabe3e8dc..3f5ae717bb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -49,30 +49,25 @@ void fused_attn_arbitrary_seqlen_bwd( const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> -// check_support -> build_plans) for an F16/BF16 forward graph with the given configuration. -// Returns an empty string iff the graph compiles end-to-end; on OK the built graph is -// inserted into the same thread-local cache used by fused_attn_arbitrary_seqlen_fwd_impl, -// so the executor cache-hits on matching descriptors. -// -// On rejection, returns a non-empty diagnostic of the form -// "[] " -// where is a stable tag mirroring cudnn_frontend::error_code_t names -// (e.g. GRAPH_NOT_SUPPORTED for cuDNN-FE rejections forwarded from the support chain). +// check if a given configuration is supported for F16/BF16 forward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. 
std::string is_supported_f16_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle); + int64_t window_size_right, bool bottom_right_diagonal, DType qkv_dtype, cudnnHandle_t handle); -// Probe: same as above for the F16/BF16 backward graph. +// check if a given configuration is supported for F16/BF16 backward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. std::string is_supported_f16_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle); + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 27bd0af3f3..c0b515138c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2997,14 +2997,12 @@ std::string is_supported_fp8_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t 
window_size_right, - bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle) { - // FP8 fwd impl rejects any qkv_format other than BSHD/SBHD/BHSD with NVTE_ERROR; mirror that - // here so the probe returns a typed rejection instead of catching the throw. const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return "[INVALID_VALUE] FP8 fused attention only supports BSHD/SBHD/BHSD layouts."; + return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; } size_t workspace_size = 0; try { @@ -3021,16 +3019,16 @@ std::string is_supported_fp8_fwd( /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr, /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, - /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), get_cudnn_fe_dtype(o_dtype), - scaling_mode, + /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(qkv_dtype), + get_cudnn_fe_dtype(o_dtype), scaling_mode, /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return ""; } catch (const std::exception& e) { - return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); + return e.what(); } catch (...) 
{ - return "[GRAPH_NOT_SUPPORTED] is_supported_fp8_fwd: unknown failure"; + return "is_supported_fp8_fwd: unknown failure."; } } @@ -3039,16 +3037,14 @@ std::string is_supported_fp8_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle) { const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return "[INVALID_VALUE] FP8 fused attention only supports BSHD/SBHD/BHSD layouts."; + return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; } - // For FP8 bwd, dO data type matches O data type and dQKV data type matches Q data type - // (this mirrors the assumption used by callers of fused_attn_fp8_bwd in TE). - const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(q_dtype); + const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); const cudnn_frontend::DataType_t do_t = o_t; const cudnn_frontend::DataType_t dqkv_t = qkv_t; @@ -3082,9 +3078,9 @@ std::string is_supported_fp8_bwd( /*stream=*/static_cast(0), handle); return ""; } catch (const std::exception& e) { - return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); + return e.what(); } catch (...) 
{ - return "[GRAPH_NOT_SUPPORTED] is_supported_fp8_bwd: unknown failure"; + return "is_supported_fp8_bwd: unknown failure."; } } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index f91cdcf291..7c7460e4ea 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -42,30 +42,25 @@ void fused_attn_fp8_bwd( Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> -// check_support -> build_plans) for an FP8 forward graph with the given configuration. -// Returns an empty string iff the graph compiles end-to-end; on OK the built graph is -// inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl. -// -// On rejection, returns a non-empty diagnostic of the form -// "[] " -// where mirrors cudnn_frontend::error_code_t names (INVALID_VALUE for the -// FP8-only layout pre-filter, GRAPH_NOT_SUPPORTED for cuDNN-FE rejections forwarded -// from the support chain). +// check if a given configuration is supported for FP8 forward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. 
std::string is_supported_fp8_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); -// Probe: same as above for the FP8 backward graph. +// check if a given configuration is supported for FP8 backward; +// if it is, cache the graph built for this config, and return an empty string; +// if not, return a diagnostic message in the form of a string. std::string is_supported_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index bbcdf08995..b90749c8ee 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -199,23 +199,12 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. 
- * - * Authoritative routing: when a non-NVTE_No_Backend value is returned, the configuration - * is guaranteed to compile through cuDNN-FE (validate -> build_operation_graph -> - * create_execution_plans -> check_support -> build_plans). The router applies a small - * set of TE-specific post-filters in addition to delegating to cuDNN-FE for capability - * checks. On success the built plan is cached, so the executor avoids rebuilding. * * \param[in] is_training Whether the model is in training mode. * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] o_dtype The data type of output Tensor O. Used by the FP8 - * branch to disambiguate FP8 vs HALF/BF16 output; - * ignored by the F16/BF16 branch (pass q_dtype). - * \param[in] scaling_mode Scaling mode of the input tensors. Used by the FP8 - * branch to select among delayed/current/MXFP8 recipes; - * ignored by the F16/BF16 branch - * (pass NVTE_DELAYED_TENSOR_SCALING). + * \param[in] o_dtype The data type of Tensor O. + * \param[in] scaling_mode Scaling mode of attention. * \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] bias_type The attention bias type. * \param[in] attn_mask_type The attention mask type. @@ -232,20 +221,9 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] return_max_logit Whether to produce Max along with Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] deterministic Whether determinism is required or not. - * \param[in] handle cuDNN handle used for the support chain. Required. - * \param[out] out_reason Optional. When non-NULL, set to a null-terminated - * diagnostic string describing why the configuration - * was rejected (NVTE_No_Backend) or set to "" on - * success. Rejection messages are tagged with a - * stable category prefix that mirrors - * \c cudnn_frontend::error_code_t, e.g. - * \c "[INVALID_VALUE] ..." 
for TE post-filter - * rejections and FP8 layout pre-filter rejections, - * \c "[GRAPH_NOT_SUPPORTED] ..." for cuDNN-FE - * rejections forwarded from the support chain. The - * pointer points into per-thread storage owned by TE - * and is valid only until the next call to - * \c nvte_get_fused_attn_backend on the same thread. + * \param[in] handle cuDNN handle. + * \param[out] message Empty string on success, otherwise a diagnostic string + * describing why the configuration was rejected. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, @@ -254,7 +232,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, - const char **out_reason); + const char **message); /*! \brief Compute dot product attention with separate Q, K and V. 
* diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index fdfa47da8f..fb5096b9a7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -83,6 +83,13 @@ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "NVTEScalingMode", pybind11::module_local()) \ + .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) \ + .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) \ + .value("NVTE_BLOCK_SCALING_1D", NVTEScalingMode::NVTE_BLOCK_SCALING_1D) \ + .value("NVTE_BLOCK_SCALING_2D", NVTEScalingMode::NVTE_BLOCK_SCALING_2D) \ + .value("NVTE_NVFP4_1D_SCALING", NVTEScalingMode::NVTE_NVFP4_1D_SCALING) \ + .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_INVALID_SCALING); \ pybind11::enum_( \ m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 40d02f40e1..2a38e5f6bd 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -16,7 +16,7 @@ from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax -from transformer_engine_jax import NVTE_Fused_Attn_Backend +from transformer_engine_jax import NVTE_Fused_Attn_Backend, NVTEScalingMode from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, @@ -125,14 +125,22 @@ class FusedAttnHelper: def is_fused_attn_kernel_available(self): """Check if there is available fused attention kernel""" - return self.get_fused_attn_backend() != 
NVTE_Fused_Attn_Backend.NVTE_No_Backend + backend, _ = self.get_fused_attn_backend() + return backend != NVTE_Fused_Attn_Backend.NVTE_No_Backend def get_fused_attn_backend(self): - """Get the fused attention kernel backend""" + """Get the fused attention kernel backend. + + Returns a ``(backend, message)`` tuple. ``message`` is empty on success, otherwise a + diagnostic string describing why the configuration was rejected when backend = NVTE_No_Backend. + """ + q_type = jax_dtype_to_te_dtype(self.q_dtype) return transformer_engine_jax.get_fused_attn_backend( self.is_training, - jax_dtype_to_te_dtype(self.q_dtype), + q_type, jax_dtype_to_te_dtype(self.kv_dtype), + q_type, + NVTEScalingMode.NVTE_INVALID_SCALING, self.qkv_layout.value, self.attn_bias_type.value, self.attn_mask_type.value, @@ -335,7 +343,7 @@ def abstract( out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper( + backend, message = FusedAttnHelper( config.is_training, q_dtype, k_dtype, @@ -372,7 +380,7 @@ def abstract( ) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: - raise ValueError(f"Unsupported {backend=}") + raise ValueError(f"Unsupported backend: {message}") softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..629b6dc3bf 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "common/common.h" @@ -146,12 +147,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); -NVTE_Fused_Attn_Backend GetFusedAttnBackend( - bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, - 
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right, bool deterministic); +// Returns (backend, message). `message` is empty on success, otherwise a diagnostic string +// describing why the configuration was rejected when backend = NVTE_No_Backend. +std::tuple GetFusedAttnBackend( + bool is_training, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, + size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, + size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 669570daa5..83bddcabb1 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -12,21 +12,21 @@ namespace transformer_engine { namespace jax { -NVTE_Fused_Attn_Backend GetFusedAttnBackend( - bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right, bool deterministic) { +std::tuple GetFusedAttnBackend( + bool is_training, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode 
scaling_mode, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, + size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, + size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool deterministic) { auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const char *message = nullptr; auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), - static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, - mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, handle, - /*out_reason=*/nullptr); - return backend; + static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, mask_type, softmax_type, + dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, + v_head_dim, window_size_left, window_size_right, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, handle, &message); + return {backend, message ? 
std::string(message) : std::string()}; } /* @@ -279,10 +279,10 @@ static void FusedAttnForwardImpl( auto _handle_fwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), - static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, + static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_fwd, /*out_reason=*/nullptr); + /*cuda_graph=*/false, deterministic, _handle_fwd, /*message=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -557,10 +557,10 @@ static void FusedAttnBackwardImpl( auto _handle_bwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), - static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, + static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_bwd, /*out_reason=*/nullptr); + /*cuda_graph=*/false, deterministic, _handle_bwd, /*message=*/nullptr); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 
ed87423534..f236d5a26c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1229,10 +1229,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if fp8 and fp8_meta["recipe"].fp8_dpa: q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) kv_type = q_type - fused_attention_backend = tex.get_fused_attn_backend( + fused_attention_backend, reject_message = tex.get_fused_attn_backend( is_training, q_type, kv_type, + q_type, + tex.NVTEScalingMode.NVTE_INVALID_SCALING, QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], @@ -1251,7 +1253,10 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt deterministic, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: - logger.debug("Disabling FusedAttention as no backend supports the provided input") + logger.debug( + "Disabling FusedAttention as %s", + reject_message, + ) use_fused_attention = False fused_attention_backend = None if ( diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..733f98e575 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -75,12 +75,15 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T * Attention **************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool 
return_max_logit, bool cuda_graph, bool deterministic); +// Returns (backend, message). `message` is empty on success, otherwise a diagnostic string +// describing why the configuration was rejected when backend = NVTE_No_Backend. +std::tuple get_fused_attn_backend( + bool is_training, const DType q_dtype, const DType kv_dtype, const DType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool return_max_logit, bool cuda_graph, bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 3af4ba3831..0c5f99ef33 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -41,26 +41,22 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s namespace transformer_engine::pytorch { // get the fused attention backend -// -// NOTE: the underlying nvte_get_fused_attn_backend now takes o_dtype and scaling_mode in -// addition to q_dtype/kv_dtype. For the F16/BF16 routing path those are ignored, so we pass -// q_dtype as o_dtype and DELAYED_TENSOR_SCALING. This Python-facing wrapper therefore keeps -// its existing signature; FP8 callers that want authoritative routing for non-default scaling -// recipes should add o_dtype / scaling_mode parameters in a follow-up. 
-NVTE_Fused_Attn_Backend get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { +std::tuple get_fused_attn_backend( + bool is_training, const DType q_dtype, const DType kv_dtype, const DType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool return_max_logit, bool cuda_graph, bool deterministic) { auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + const char *message = nullptr; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), - static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, - attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_logit, - cuda_graph, deterministic, handle, /*out_reason=*/nullptr); - return fused_attention_backend; + static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, attn_mask_type, + softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, + head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_logit, cuda_graph, + deterministic, handle, &message); + return {fused_attention_backend, message ? 
std::string(message) : std::string()}; } // helper function for S and dP quantizers From 81e59a9c86d8a6de3686244e7a1801e5cd3db487 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 00:11:48 +0000 Subject: [PATCH 05/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 15 ++++---- .../fused_attn_f16_arbitrary_seqlen.cu | 36 ++++++++++--------- .../fused_attn_f16_arbitrary_seqlen.h | 27 +++++++------- .../common/fused_attn/fused_attn_fp8.cu | 34 +++++++++--------- .../common/fused_attn/fused_attn_fp8.h | 30 ++++++++-------- 5 files changed, 74 insertions(+), 68 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 21b0d80f4f..41607f05f7 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -263,8 +263,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { set_message(message, - "Configuration requires 64-bit ragged offsets, which require " - "cuDNN >= 9.5."); + "Configuration requires 64-bit ragged offsets, which require " + "cuDNN >= 9.5."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -274,7 +274,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { set_message(message, - "THD format requires PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT mask."); + "THD format requires PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT mask."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -285,8 +285,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( 
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_message(message, - "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN."); + set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -300,8 +299,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (is_fp8) { if (return_max_logit) { - set_message(message, - "FP8 fused attention does not support return_max_logit=True."); + set_message(message, "FP8 fused attention does not support return_max_logit=True."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } const DType qkv_t = static_cast(q_dtype); @@ -354,8 +352,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - set_message(message, - "Unsupported QKV dtype qkv_dtype=" + std::to_string(q_dtype) + " ."); + set_message(message, "Unsupported QKV dtype qkv_dtype=" + std::to_string(q_dtype) + " ."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 81e66d8800..3a2b296ffc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1334,13 +1334,14 @@ void fused_attn_arbitrary_seqlen_bwd( } } -std::string is_supported_f16_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, - bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t 
window_size_right, bool bottom_right_diagonal, DType qkv_dtype, - cudnnHandle_t handle) { +std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, bool return_max_logit, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + DType qkv_dtype, cudnnHandle_t handle) { const auto b = static_cast(batch); const auto h = static_cast(num_attn_heads); const auto sq = static_cast(max_seqlen_q); @@ -1375,8 +1376,8 @@ std::string is_supported_f16_fwd( fused_attn::fused_attn_arbitrary_seqlen_fwd_impl( b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), static_cast(head_dim_v), max_b, max_t_q, max_t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - bias_sq, bias_skv, is_training, return_max_logit, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, + bias_skv, is_training, return_max_logit, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, @@ -1396,12 +1397,13 @@ std::string is_supported_f16_fwd( } } -std::string is_supported_f16_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, cudnnHandle_t handle) { +std::string 
is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, cudnnHandle_t handle) { const auto b = static_cast(batch); const auto h = static_cast(num_attn_heads); const auto sq = static_cast(max_seqlen_q); @@ -1430,8 +1432,8 @@ std::string is_supported_f16_bwd( fused_attn::fused_attn_arbitrary_seqlen_bwd_impl( b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), static_cast(head_dim_v), max_b, max_t_q, max_t_kv, bias_b, bias_h, bias_sq, - bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, - dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr, /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 3f5ae717bb..fe94d0c10c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -52,22 +52,25 @@ void fused_attn_arbitrary_seqlen_bwd( // check if a given configuration is supported for F16/BF16 forward; // if it is, cache the graph built for this config, and return an empty string; // if not, return a diagnostic 
message in the form of a string. -std::string is_supported_f16_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, - bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, DType qkv_dtype, cudnnHandle_t handle); +std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, bool return_max_logit, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + DType qkv_dtype, cudnnHandle_t handle); // check if a given configuration is supported for F16/BF16 backward; // if it is, cache the graph built for this config, and return an empty string; // if not, return a diagnostic message in the form of a string. 
-std::string is_supported_f16_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, cudnnHandle_t handle); +std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index db4f25c05f..40b4aa4299 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1325,13 +1325,14 @@ void fused_attn_fp8_bwd( } } -std::string is_supported_fp8_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, - cudnnHandle_t handle) { +std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float 
p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle) { const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { @@ -1352,8 +1353,8 @@ std::string is_supported_fp8_fwd( /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr, /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, - /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(qkv_dtype), - get_cudnn_fe_dtype(o_dtype), scaling_mode, + /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(qkv_dtype), get_cudnn_fe_dtype(o_dtype), + scaling_mode, /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); @@ -1365,13 +1366,14 @@ std::string is_supported_fp8_fwd( } } -std::string is_supported_fp8_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, cudnnHandle_t handle) { +std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + 
NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle) { const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 96f5d54968..d52dfd246b 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -45,22 +45,24 @@ void fused_attn_fp8_bwd( // check if a given configuration is supported for FP8 forward; // if it is, cache the graph built for this config, and return an empty string; // if not, return a diagnostic message in the form of a string. -std::string is_supported_fp8_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, - cudnnHandle_t handle); +std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, bool is_training, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle); // check if a given configuration 
is supported for FP8 backward; // if it is, cache the graph built for this config, and return an empty string; // if not, return a diagnostic message in the form of a string. -std::string is_supported_fp8_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, cudnnHandle_t handle); +std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle); } // namespace transformer_engine From 6c5126db51cf52657eafa01d853503fd254113e2 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 17:22:09 -0700 Subject: [PATCH 06/23] fix compilation errors Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn_fp8.cu | 6 ++++-- transformer_engine/common/fused_attn/fused_attn_fp8.h | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 40b4aa4299..842e3958bc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1336,7 +1336,8 @@ 
std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; + return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + + std::to_string(static_cast(qkv_format)) + "."; } size_t workspace_size = 0; try { @@ -1377,7 +1378,8 @@ std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; + return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + + std::to_string(static_cast(qkv_format)) + "."; } const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index d52dfd246b..21487898a6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -36,8 +36,8 @@ void fused_attn_fp8_bwd( NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, - const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, + const Tensor *input_dO_f16, const Tensor *input_M, const Tensor 
*input_S, + const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); From d35bff72911e239577e63aa99db96db7571dbb97 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 00:23:00 +0000 Subject: [PATCH 07/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/fused_attn/fused_attn_fp8.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 21487898a6..01c7561402 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -37,10 +37,10 @@ void fused_attn_fp8_bwd( bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_S, - const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, const Tensor *output_dQ, + const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // check if a given configuration is 
supported for FP8 forward; // if it is, cache the graph built for this config, and return an empty string; From f6fc58568823aae209065c387f55d59c44d2dd4f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 18:08:56 -0700 Subject: [PATCH 08/23] remove handle from API; add bottom_right_diagonal Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 1 + .../common/fused_attn/fused_attn.cpp | 49 ++++++++++-------- .../common/fused_attn/fused_attn_fp8.cu | 10 ---- .../include/transformer_engine/fused_attn.h | 51 ++++++++++--------- transformer_engine/jax/attention.py | 12 ++++- .../jax/cpp_extensions/attention.py | 3 ++ transformer_engine/jax/csrc/extensions.h | 3 +- .../jax/csrc/extensions/attention.cpp | 19 +++---- .../attention/dot_product_attention/utils.py | 3 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cpp | 8 ++- 11 files changed, 85 insertions(+), 76 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index e8da8c7366..7dc7cc4c97 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -460,6 +460,7 @@ def _check_configs(self): self.head_dim_qk, self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, + self.attn_mask_type.is_bottom_right(), ).get_fused_attn_backend() if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip(message) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 41607f05f7..9c5b91d1fc 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -246,12 +246,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t 
max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, const char **message) { using namespace transformer_engine; set_message(message, ""); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); + cudnnHandle_t handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -278,19 +279,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - // avoid CUDA graph issue with cuDNN <= 9.15 - if (cudnn_runtime_version <= 91500 && is_training && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - (max_seqlen_kv % 128 != 0) && cuda_graph && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN."); - return NVTE_Fused_Attn_Backend::NVTE_No_Backend; - } - + // Use batch=1 for the probe to keep graph caches minimal; batch is not part of cuDNN-FE's + // support-check criteria. All other params are passed through verbatim so the cached graph + // matches what the eventual nvte_fused_attn_fwd/bwd call will build. 
constexpr size_t probe_batch = 1; - constexpr bool probe_bottom_right_diagonal = false; const bool is_fp8 = (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2); @@ -302,12 +294,18 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( set_message(message, "FP8 fused attention does not support return_max_logit=True."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } + if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && + qkv_format != NVTE_QKV_Format::NVTE_BHSD) { + set_message(message, "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + + std::to_string(static_cast(qkv_format)) + "."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } const DType qkv_t = static_cast(q_dtype); const DType o_t = static_cast(o_dtype); std::string fwd_reason = is_supported_fp8_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, qkv_t, o_t, scaling_mode, + window_size_left, window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, handle); if (!fwd_reason.empty()) { set_message(message, fwd_reason); @@ -317,7 +315,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( std::string bwd_reason = is_supported_fp8_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, qkv_t, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, scaling_mode, handle); if (!bwd_reason.empty()) { set_message(message, bwd_reason); @@ -328,11 +326,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if (is_f16_or_bf16) { + if (cudnn_runtime_version <= 91500 && is_training && + (qkv_format 
== NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + (max_seqlen_kv % 128 != 0) && cuda_graph && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } const DType qkv_t = static_cast(q_dtype); std::string fwd_reason = is_supported_f16_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, qkv_t, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, qkv_t, handle); if (!fwd_reason.empty()) { set_message(message, fwd_reason); @@ -342,7 +349,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( std::string bwd_reason = is_supported_f16_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, qkv_t, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, handle); if (!bwd_reason.empty()) { set_message(message, bwd_reason); @@ -444,8 +451,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, - window_size_right, return_max_logit, cuda_graph, /*deterministic=*/false, handle, - /*message=*/nullptr); + window_size_right, 
bottom_right_diagonal, return_max_logit, cuda_graph, + /*deterministic=*/false, /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { fused_attn_arbitrary_seqlen_fwd( @@ -528,8 +535,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( /*is_training=*/true, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, - window_size_left, window_size_right, /*return_max_logit=*/false, cuda_graph, deterministic, - handle, /*message=*/nullptr); + window_size_left, window_size_right, bottom_right_diagonal, /*return_max_logit=*/false, + cuda_graph, deterministic, /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { size_t i = 0; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 842e3958bc..f4064a8d34 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1334,11 +1334,6 @@ std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle) { const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && - qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + - std::to_string(static_cast(qkv_format)) + "."; - } size_t workspace_size = 0; try { fused_attn::fused_attn_fp8_fwd_impl( @@ -1376,11 +1371,6 @@ std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num bool deterministic, DType qkv_dtype, DType 
o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle) { const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && - qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + - std::to_string(static_cast(qkv_format)) + "."; - } const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); const cudnn_frontend::DataType_t do_t = o_t; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 9bcbcc5716..85e3ea68ed 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -198,30 +198,31 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * - * \param[in] is_training Whether the model is in training mode. - * \param[in] q_dtype The data type of Tensor Q. - * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] o_dtype The data type of Tensor O. - * \param[in] scaling_mode Scaling mode of attention. - * \param[in] qkv_layout The layout of Tensors Q, K, V. - * \param[in] bias_type The attention bias type. - * \param[in] attn_mask_type The attention mask type. - * \param[in] softmax_type The attention softmax type. - * \param[in] dropout The dropout probability. - * \param[in] num_attn_heads The number of heads in Q. - * \param[in] num_gqa_groups The number of heads in K, V. - * \param[in] max_seqlen_q The sequence length of Q. - * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim_qk The head dimension of Q, K. - * \param[in] head_dim_v The head dimension of V. 
- * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] return_max_logit Whether to produce Max along with Stats. - * \param[in] cuda_graph Whether cuda graph capture is enabled or not. - * \param[in] deterministic Whether determinism is required or not. - * \param[in] handle cuDNN handle. - * \param[out] message Empty string on success, otherwise a diagnostic string - * describing why the configuration was rejected. + * \param[in] is_training Whether the model is in training mode. + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] o_dtype The data type of Tensor O. + * \param[in] scaling_mode Scaling mode of attention. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the + * bottom right corner of the softmax matrix. + * \param[in] return_max_logit Whether to produce Max along with Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] deterministic Whether determinism is required or not. 
+ * \param[out] message Empty string on success, otherwise a diagnostic string + * describing why the configuration was rejected. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, @@ -229,7 +230,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, const char **message); /*! \brief Compute dot product attention with separate Q, K and V. diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index f54a043fd2..d0e125297f 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -339,13 +339,22 @@ def is_fused_attn_kernel_available( head_dim_qk, head_dim_v, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, ): """ - To check whether the fused attention kernel is supported + To check whether the fused attention kernel is supported. + + If ``bottom_right_diagonal`` is None, it is derived from the mask type, matching the + convention used everywhere else in JAX TE (see ``_FusedAttnConfig`` constructions). 
""" window_size_tuple = (-1, -1) if window_size is None else window_size def make_helper(attn_mask_type): + bottom_right = ( + attn_mask_type.is_bottom_right() + if bottom_right_diagonal is None + else bottom_right_diagonal + ) return tex.FusedAttnHelper( is_training, q_dtype, @@ -362,6 +371,7 @@ def make_helper(attn_mask_type): head_dim_qk, head_dim_v, window_size_tuple, + bottom_right, ) return make_helper(attn_mask_type).is_fused_attn_kernel_available() diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 1631afe4f4..e6cbb10e44 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -122,6 +122,7 @@ class FusedAttnHelper: head_dim_qk: int head_dim_v: int window_size: Tuple[int, int] + bottom_right_diagonal: bool def is_fused_attn_kernel_available(self): """Check if there is available fused attention kernel""" @@ -154,6 +155,7 @@ def get_fused_attn_backend(self): self.head_dim_v, self.window_size[0], self.window_size[1], + self.bottom_right_diagonal, not self.is_non_deterministic_allowed(), ) @@ -359,6 +361,7 @@ def abstract( q_head_dim, v_head_dim, config.window_size, + config.bottom_right_diagonal, ).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 629b6dc3bf..d958193a7d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -154,7 +154,8 @@ std::tuple GetFusedAttnBackend( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, - size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool deterministic); + size_t v_head_dim, int64_t 
window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 4ce09368d8..d5673df8a5 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -5,7 +5,6 @@ ************************************************************************/ #include "../extensions.h" -#include "common/cudnn_utils.h" #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -17,15 +16,15 @@ std::tuple GetFusedAttnBackend( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, - size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool deterministic) { - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic) { const char *message = nullptr; auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, - v_head_dim, window_size_left, window_size_right, - /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, handle, &message); + v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, &message); return {backend, message ? 
std::string(message) : std::string()}; } @@ -265,13 +264,12 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto _handle_fwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_fwd, /*message=*/nullptr); + qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, /*message=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -543,13 +541,12 @@ static void FusedAttnBackwardImpl( /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - auto _handle_bwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_bwd, /*message=*/nullptr); + qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, /*message=*/nullptr); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, 
attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7f97a1e0f2..38542586d2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1244,13 +1244,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt head_dim_v, window_size[0], window_size[1], + bottom_right_diagonal, return_max_logit, cuda_graph, deterministic, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug( - "Disabling FusedAttention as %s", + "Disabling FusedAttention: %s", reject_message, ) use_fused_attention = False diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 733f98e575..016721f8b0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -83,7 +83,7 @@ std::tuple get_fused_attn_backend( NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool return_max_logit, bool cuda_graph, bool deterministic); + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 4732d47908..2f5c7058c5 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -6,7 
+6,6 @@ #include "../extensions.h" #include "common.h" -#include "common/cudnn_utils.h" #include "pybind.h" namespace { @@ -47,15 +46,14 @@ std::tuple get_fused_attn_backend( NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool return_max_logit, bool cuda_graph, bool deterministic) { - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic) { const char *message = nullptr; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_logit, cuda_graph, - deterministic, handle, &message); + head_dim_qk, head_dim_v, window_size_left, window_size_right, bottom_right_diagonal, + return_max_logit, cuda_graph, deterministic, &message); return {fused_attention_backend, message ? 
std::string(message) : std::string()}; } From 3e666b0c59d90a2af5ab4be38b892e1549aefd91 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 18:30:39 -0700 Subject: [PATCH 09/23] add batch_size to API Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/jax/test_distributed_fused_attn.py | 3 +++ tests/jax/test_fused_attn.py | 1 + .../common/fused_attn/fused_attn.cpp | 19 +++++++------------ .../include/transformer_engine/fused_attn.h | 3 ++- transformer_engine/jax/attention.py | 2 ++ .../jax/cpp_extensions/attention.py | 4 ++++ transformer_engine/jax/csrc/extensions.h | 10 +++++----- .../jax/csrc/extensions/attention.cpp | 16 ++++++++-------- transformer_engine/jax/flax/transformer.py | 3 +++ .../attention/dot_product_attention/utils.py | 1 + transformer_engine/pytorch/csrc/extensions.h | 13 +++++++------ .../pytorch/csrc/extensions/attention.cpp | 15 ++++++++------- 12 files changed, 51 insertions(+), 39 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 50c5de1db7..39efabc598 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -75,6 +75,7 @@ def impl_test_self_attn( if not is_fused_attn_kernel_available( is_training, + batch, dtype, dtype, QKVLayout.BS3HD, @@ -227,6 +228,7 @@ def test_cross_attn( if not is_fused_attn_kernel_available( is_training, + batch, dtype, dtype, QKVLayout.BSHD_BS2HD, @@ -368,6 +370,7 @@ def impl_test_context_parallel_attn( def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( is_training, + batch, dtype, dtype, qkv_layout, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 7dc7cc4c97..88c485db81 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -446,6 +446,7 @@ def _check_configs(self): self.backend, message = FusedAttnHelper( self.is_training, + 
self.batch_size, self.dtype, self.dtype, self.qkv_layout, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9c5b91d1fc..e0d524c783 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -241,7 +241,7 @@ void set_message(const char **message, const std::string &reason) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, @@ -279,11 +279,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - // Use batch=1 for the probe to keep graph caches minimal; batch is not part of cuDNN-FE's - // support-check criteria. All other params are passed through verbatim so the cached graph - // matches what the eventual nvte_fused_attn_fwd/bwd call will build. 
- constexpr size_t probe_batch = 1; - const bool is_fp8 = (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2); const bool is_f16_or_bf16 = @@ -303,7 +298,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const DType qkv_t = static_cast(q_dtype); const DType o_t = static_cast(o_dtype); std::string fwd_reason = is_supported_fp8_fwd( - probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, handle); @@ -313,7 +308,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if (is_training) { std::string bwd_reason = is_supported_fp8_bwd( - probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, scaling_mode, handle); @@ -337,7 +332,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } const DType qkv_t = static_cast(q_dtype); std::string fwd_reason = is_supported_f16_fwd( - probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, qkv_t, handle); @@ -347,7 +342,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if (is_training) { std::string bwd_reason = is_supported_f16_bwd( - probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + batch_size, 
num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, handle); @@ -449,7 +444,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, + is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph, /*deterministic=*/false, /*message=*/nullptr); @@ -533,7 +528,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - /*is_training=*/true, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, + /*is_training=*/true, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, bottom_right_diagonal, /*return_max_logit=*/false, cuda_graph, deterministic, /*message=*/nullptr); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 85e3ea68ed..227afed24e 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -199,6 +199,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. 
* * \param[in] is_training Whether the model is in training mode. + * \param[in] batch_size Batch size. * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. * \param[in] o_dtype The data type of Tensor O. @@ -225,7 +226,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * describing why the configuration was rejected. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index d0e125297f..ac6cf8975c 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -325,6 +325,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str): def is_fused_attn_kernel_available( is_training, + batch_size, q_dtype, kv_dtype, qkv_layout, @@ -357,6 +358,7 @@ def make_helper(attn_mask_type): ) return tex.FusedAttnHelper( is_training, + batch_size, q_dtype, kv_dtype, qkv_layout, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index e6cbb10e44..2a533c3f3e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -108,6 +108,7 @@ class FusedAttnHelper: """ is_training: bool + batch_size: int q_dtype: jnp.dtype kv_dtype: jnp.dtype qkv_layout: QKVLayout @@ -138,6 +139,7 @@ def get_fused_attn_backend(self): q_type = jax_dtype_to_te_dtype(self.q_dtype) return transformer_engine_jax.get_fused_attn_backend( self.is_training, + self.batch_size, q_type, 
jax_dtype_to_te_dtype(self.kv_dtype), q_type, @@ -345,8 +347,10 @@ def abstract( out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) # backend determines the softmax buffer shape/dtype + input_batch = reduce(operator.mul, batch_shape) backend, message = FusedAttnHelper( config.is_training, + input_batch, q_dtype, k_dtype, config.qkv_layout, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index d958193a7d..1e8d99c3d8 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -150,11 +150,11 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); // Returns (backend, message). `message` is empty on success, otherwise a diagnostic string // describing why the configuration was rejected when backend = NVTE_No_Backend. std::tuple GetFusedAttnBackend( - bool is_training, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, - size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, - size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index d5673df8a5..5cd3265c3e 100644 --- 
a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -12,15 +12,15 @@ namespace transformer_engine { namespace jax { std::tuple GetFusedAttnBackend( - bool is_training, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, - size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, - size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic) { const char *message = nullptr; auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(q_dtype), static_cast(kv_dtype), + is_training, batch_size, static_cast(q_dtype), static_cast(kv_dtype), static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, @@ -265,7 +265,7 @@ static void FusedAttnForwardImpl( auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(dtype), static_cast(dtype), + is_training, input_batch, static_cast(dtype), static_cast(dtype), static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, 
attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, @@ -542,7 +542,7 @@ static void FusedAttnBackwardImpl( NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(dtype), static_cast(dtype), + is_training, input_batch, static_cast(dtype), static_cast(dtype), static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index a2e7920843..184547aa92 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -748,6 +748,8 @@ def __call__( enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1")) sequence_dim = 0 if self.transpose_batch_sequence else 1 + batch_dim = 1 - sequence_dim + batch_size = query.shape[batch_dim] seqlen_q = query.shape[sequence_dim] if qkv_layout == QKVLayout.BS3HD: seqlen_kv = seqlen_q @@ -763,6 +765,7 @@ def __call__( has_fused_attn_kernel = is_fused_attn_kernel_available( # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. 
not deterministic, + batch_size, input_dtype, # self._assert_dtypes enforces Q, K, V, bias to have the same dtype so using input_dtype as kv dtype is sufficient input_dtype, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 38542586d2..52bb687851 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1227,6 +1227,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt kv_type = q_type fused_attention_backend, reject_message = tex.get_fused_attn_backend( is_training, + batch_size, q_type, kv_type, q_type, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 016721f8b0..205e7eb834 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -78,12 +78,13 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T // Returns (backend, reason). `reason` is empty on success, otherwise a diagnostic string // describing why the configuration was rejected when backend = NVTE_No_Backend. 
std::tuple get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, const DType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, - size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic); + bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, + const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, + bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2f5c7058c5..41dcd3301a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -41,15 +41,16 @@ namespace transformer_engine::pytorch { // get the fused attention backend std::tuple get_fused_attn_backend( - bool is_training, const DType q_dtype, const DType kv_dtype, const DType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, - size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, 
- size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic) { + bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, + const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, + bool deterministic) { const char *message = nullptr; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, static_cast(q_dtype), static_cast(kv_dtype), + is_training, batch_size, static_cast(q_dtype), static_cast(kv_dtype), static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, bottom_right_diagonal, From 056aba6aebbfbab3e580d3155ccb07c076bcf940 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 01:31:54 +0000 Subject: [PATCH 10/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/fused_attn/fused_attn.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index e0d524c783..628bce1b54 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -310,8 +310,8 @@ NVTE_Fused_Attn_Backend 
nvte_get_fused_attn_backend( std::string bwd_reason = is_supported_fp8_bwd( batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, - o_t, scaling_mode, handle); + window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, + scaling_mode, handle); if (!bwd_reason.empty()) { set_message(message, bwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; @@ -334,8 +334,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( std::string fwd_reason = is_supported_f16_fwd( batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, qkv_t, - handle); + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, qkv_t, handle); if (!fwd_reason.empty()) { set_message(message, fwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; @@ -344,8 +343,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( std::string bwd_reason = is_supported_f16_bwd( batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, - handle); + window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, handle); if (!bwd_reason.empty()) { set_message(message, bwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; From e054863c2fda315bd43866160e188f2be95e4aea Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 20:34:12 -0700 Subject: [PATCH 11/23] fix jax binding Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/jax/csrc/extensions/pybind.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..2d55abedc6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -206,6 +206,14 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .export_values(); + pybind11::enum_(m, "NVTEScalingMode", pybind11::module_local()) + .value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) + .value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) + .value("NVTE_BLOCK_SCALING_1D", NVTEScalingMode::NVTE_BLOCK_SCALING_1D) + .value("NVTE_BLOCK_SCALING_2D", NVTEScalingMode::NVTE_BLOCK_SCALING_2D) + .value("NVTE_NVFP4_1D_SCALING", NVTEScalingMode::NVTE_NVFP4_1D_SCALING) + .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_INVALID_SCALING); + pybind11::enum_(m, "JAXX_Quantize_Layout", pybind11::module_local()) .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE) .value("COLWISE", JAXX_Quantize_Layout::COLWISE) From a7fe928eb21df1528fe57af6e87d4ff06dbd1f9b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 22:28:35 -0700 Subject: [PATCH 12/23] specify o_dtype for FP8s Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 52bb687851..db55f2fbd3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ 
-1222,16 +1222,29 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if use_fused_attention: q_type = TE_DType[qkv_dtype] kv_type = q_type + o_type = q_type + scaling_mode = tex.NVTEScalingMode.NVTE_INVALID_SCALING if fp8 and fp8_meta["recipe"].fp8_dpa: - q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + recipe = fp8_meta["recipe"] + q_type = get_fp8_te_dtype(recipe, fprop_tensor=True) kv_type = q_type + cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + if recipe.mxfp8(): + scaling_mode = tex.NVTEScalingMode.NVTE_MXFP8_1D_SCALING + o_type = TE_DType[torch.bfloat16] + elif recipe.float8_current_scaling() and cs_o_in_f16: + scaling_mode = tex.NVTEScalingMode.NVTE_DELAYED_TENSOR_SCALING + o_type = TE_DType[torch.bfloat16] + else: + scaling_mode = tex.NVTEScalingMode.NVTE_DELAYED_TENSOR_SCALING + o_type = q_type fused_attention_backend, reject_message = tex.get_fused_attn_backend( is_training, batch_size, q_type, kv_type, - q_type, - tex.NVTEScalingMode.NVTE_INVALID_SCALING, + o_type, + scaling_mode, QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], From c9b22b5d187f4cce297c54eb29976a2e53f1361a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 8 May 2026 10:18:09 -0700 Subject: [PATCH 13/23] fix BRCM and custom_fp8 tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 12 ++++++++++++ tests/pytorch/utils.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..4c8435f246 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2570,10 +2570,21 @@ def test_custom_mha_fp8_vs_f16(dtype, model): # Test backend availability is_training = True + fp8_meta = {} + fp8_recipe = 
recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + ) + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout="bs3hd", + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, deterministic=_deterministic, ) @@ -2651,6 +2662,7 @@ def _run_custom_mha_fp8(dtype, config, backend): fp8_format=recipe.Format.HYBRID, amax_history_len=1, amax_compute_algo="most_recent", + fp8_dpa=True, ) mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda") diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 3b2e50be3f..1169849044 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -275,6 +275,7 @@ def __init__( self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" self.bias_shape = bias_shape self.window_size = check_set_window_size(self.attn_mask_type, window_size) + self.bottom_right_diagonal = self.attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"} self.context_parallel = context_parallel self.cp_comm_type = cp_comm_type self.return_max_logit = return_max_logit @@ -351,6 +352,7 @@ def test(): head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, window_size=config.window_size, + bottom_right_diagonal=config.bottom_right_diagonal, alibi_slopes_shape=alibi_slopes_shape, core_attention_bias_type=config.attn_bias_type, core_attention_bias_shape=core_attention_bias_shape, From ac44e66beba56f16d1dd4c00d13aff8435025cb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 17:19:08 +0000 Subject: [PATCH 14/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 1169849044..240acb1bd0 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -275,7 +275,10 @@ def __init__( self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" self.bias_shape = bias_shape self.window_size = check_set_window_size(self.attn_mask_type, window_size) - self.bottom_right_diagonal = self.attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"} + self.bottom_right_diagonal = self.attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } self.context_parallel = context_parallel self.cp_comm_type = cp_comm_type self.return_max_logit = return_max_logit From 9131b2de638ae1a40b6bae4681b2cdf73632a33e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 8 May 2026 12:18:45 -0700 Subject: [PATCH 15/23] add o_format/etc to API and other tweaks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 65 +++++++++------- .../fused_attn_f16_arbitrary_seqlen.cu | 33 ++++---- .../fused_attn_f16_arbitrary_seqlen.h | 21 +++-- .../common/fused_attn/fused_attn_fp8.cu | 52 +++++++------ .../common/fused_attn/fused_attn_fp8.h | 30 +++++--- .../include/transformer_engine/fused_attn.h | 38 +++++++-- transformer_engine/jax/attention.py | 12 ++- .../jax/cpp_extensions/attention.py | 18 ++++- transformer_engine/jax/csrc/extensions.h | 17 ++-- .../jax/csrc/extensions/attention.cpp | 77 ++++++++++++++----- .../jax/csrc/extensions/pybind.cpp | 6 +- transformer_engine/jax/flax/transformer.py | 6 +- .../attention/dot_product_attention/utils.py | 13 ++++ transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/attention.cpp | 9 ++- 15 files changed, 274 insertions(+), 125 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 
628bce1b54..47ee2ff9e5 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -231,10 +231,13 @@ namespace { // re-used (cleared + re-populated) on every call to nvte_get_fused_attn_backend on this thread thread_local std::string fused_attn_backend_message_buffer; -void set_message(const char **message, const std::string &reason) { - if (message == nullptr) return; - fused_attn_backend_message_buffer = reason; - *message = fused_attn_backend_message_buffer.c_str(); +// Stash `reason` in the thread-local buffer and, if the caller asked for a diagnostic, +// publish a NUL-terminated pointer to it via `*message`. Safe to call with `message == nullptr`. +void set_message(const char **message, std::string reason) { + fused_attn_backend_message_buffer = std::move(reason); + if (message != nullptr) { + *message = fused_attn_backend_message_buffer.c_str(); + } } } // namespace @@ -242,12 +245,14 @@ void set_message(const char **message, const std::string &reason) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, - size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, - const char **message) { + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type 
softmax_type, + float attn_scale, float dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool return_max_logit, bool cuda_graph, bool deterministic, const char **message) { using namespace transformer_engine; set_message(message, ""); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); @@ -299,21 +304,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const DType o_t = static_cast(o_dtype); std::string fwd_reason = is_supported_fp8_fwd( batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, - handle); + head_dim_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, o_format, + qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, handle); if (!fwd_reason.empty()) { - set_message(message, fwd_reason); + set_message(message, std::move(fwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { std::string bwd_reason = is_supported_fp8_bwd( batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + head_dim_v, attn_scale, dropout, qkv_layout, o_format, do_format, dqkv_layout, + qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, scaling_mode, handle); if (!bwd_reason.empty()) { - set_message(message, bwd_reason); + set_message(message, std::move(bwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } 
@@ -331,21 +337,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } const DType qkv_t = static_cast(q_dtype); + const DType o_t = static_cast(o_dtype); std::string fwd_reason = is_supported_f16_fwd( batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, qkv_t, handle); + head_dim_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, o_format, + qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, handle); if (!fwd_reason.empty()) { - set_message(message, fwd_reason); + set_message(message, std::move(fwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { std::string bwd_reason = is_supported_f16_bwd( batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, handle); + head_dim_v, attn_scale, dropout, qkv_layout, o_format, do_format, dqkv_layout, + qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, + scaling_mode, handle); if (!bwd_reason.empty()) { - set_message(message, bwd_reason); + set_message(message, std::move(bwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } @@ -442,8 +452,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, b, Q_type, KV_type, O_type, 
scaling_mode, qkv_layout, bias_type, attn_mask_type, - softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, + is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, o_format, + /*do_format=*/o_format, /*dqkv_layout=*/qkv_layout, qkv_scale_inv_format, + /*do_scale_inv_format=*/qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, + attn_scale, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph, /*deterministic=*/false, /*message=*/nullptr); @@ -526,8 +538,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - /*is_training=*/true, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, - attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + /*is_training=*/true, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, o_format, + do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, + softmax_type, attn_scale, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, bottom_right_diagonal, /*return_max_logit=*/false, cuda_graph, deterministic, /*message=*/nullptr); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 3a2b296ffc..2b6ed2fca4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1337,11 +1337,15 @@ void fused_attn_arbitrary_seqlen_bwd( std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t 
head_dim_qk, size_t head_dim_v, bool is_training, bool return_max_logit, - float p_dropout, NVTE_QKV_Layout qkv_layout, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, + [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - DType qkv_dtype, cudnnHandle_t handle) { + DType qkv_dtype, [[maybe_unused]] DType o_dtype, + [[maybe_unused]] NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { const auto b = static_cast(batch); const auto h = static_cast(num_attn_heads); const auto sq = static_cast(max_seqlen_q); @@ -1369,17 +1373,15 @@ std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num const int64_t bias_sq = has_bias ? sq : 0; const int64_t bias_skv = has_bias ? skv : 0; - const NVTE_QKV_Format o_format = q_format; - size_t workspace_size = 0; try { fused_attn::fused_attn_arbitrary_seqlen_fwd_impl( b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), static_cast(head_dim_v), max_b, max_t_q, max_t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, - bias_skv, is_training, return_max_logit, - /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + bias_skv, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, o_format, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr, /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, @@ -1399,11 +1401,18 @@ std::string is_supported_f16_fwd(size_t batch, 
size_t num_attn_heads, size_t num std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, + [[maybe_unused]] NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, DType qkv_dtype, cudnnHandle_t handle) { + bool deterministic, DType qkv_dtype, + [[maybe_unused]] DType o_dtype, + [[maybe_unused]] NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { const auto b = static_cast(batch); const auto h = static_cast(num_attn_heads); const auto sq = static_cast(max_seqlen_q); @@ -1423,16 +1432,12 @@ std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num const int64_t bias_sq = has_bias ? sq : 0; const int64_t bias_skv = has_bias ? 
skv : 0; - const NVTE_QKV_Format o_format = q_format; - const NVTE_QKV_Format do_format = o_format; - const NVTE_QKV_Layout dqkv_layout = qkv_layout; - size_t workspace_size = 0; try { fused_attn::fused_attn_arbitrary_seqlen_bwd_impl( b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), static_cast(head_dim_v), max_b, max_t_q, max_t_kv, bias_b, bias_h, bias_sq, - bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, + bias_skv, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index fe94d0c10c..078b6c700d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -55,22 +55,29 @@ void fused_attn_arbitrary_seqlen_bwd( std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, bool return_max_logit, - float p_dropout, NVTE_QKV_Layout qkv_layout, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - DType qkv_dtype, cudnnHandle_t handle); + DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle); // check if a given configuration is supported for F16/BF16 backward; // if it is, cache the graph 
built for this config, and return an empty string; // if not, return a diagnostic message in the form of a string. std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, DType qkv_dtype, cudnnHandle_t handle); + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, + DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f4064a8d34..fb0790f230 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1327,22 +1327,24 @@ void fused_attn_fp8_bwd( std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, bool is_training, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, cudnnHandle_t handle) { - const NVTE_QKV_Format 
qkv_format = nvte_get_qkv_format(qkv_layout); + size_t head_dim_v, bool is_training, + [[maybe_unused]] bool return_max_logit, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { size_t workspace_size = 0; try { fused_attn::fused_attn_fp8_fwd_impl( static_cast(batch), static_cast(num_attn_heads), static_cast(num_gqa_groups), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), static_cast(head_dim_qk), - static_cast(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout, - qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, + static_cast(head_dim_v), is_training, attn_scale, p_dropout, qkv_layout, o_format, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, /*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, @@ -1350,8 +1352,7 @@ std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(qkv_dtype), get_cudnn_fe_dtype(o_dtype), - scaling_mode, - /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + scaling_mode, qkv_scale_inv_format, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return ""; @@ -1364,13 +1365,16 @@ std::string is_supported_fp8_fwd(size_t batch, size_t 
num_attn_heads, size_t num std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, DType qkv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, cudnnHandle_t handle) { - const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, + DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); const cudnn_frontend::DataType_t do_t = o_t; @@ -1381,10 +1385,9 @@ std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num static_cast(batch), static_cast(num_attn_heads), static_cast(num_gqa_groups), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), static_cast(head_dim_qk), - static_cast(head_dim_v), /*scaling_factor=*/1.0f, p_dropout, qkv_layout, - /*o_format=*/qkv_format, /*do_format=*/qkv_format, /*dqkv_layout=*/qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, + static_cast(head_dim_v), attn_scale, p_dropout, qkv_layout, o_format, + do_format, dqkv_layout, bias_type, mask_type, softmax_type, 
window_size_left, + window_size_right, bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, /*devPtrdO=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, @@ -1399,8 +1402,7 @@ std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num /*devPtrDescaledO_t=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, qkv_t, o_t, do_t, dqkv_t, scaling_mode, - /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, - /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + qkv_scale_inv_format, do_scale_inv_format, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return ""; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 01c7561402..078737b99e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -47,22 +47,28 @@ void fused_attn_fp8_bwd( // if not, return a diagnostic message in the form of a string. 
std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, bool is_training, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, DType qkv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, cudnnHandle_t handle); + size_t head_dim_v, bool is_training, bool return_max_logit, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, + DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle); // check if a given configuration is supported for FP8 backward; // if it is, cache the graph built for this config, and return an empty string; // if not, return a diagnostic message in the form of a string. 
std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, DType qkv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, cudnnHandle_t handle); + size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, + DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 227afed24e..0d20712207 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -197,6 +197,11 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. + * + * This call exercises cudnn-frontend's support checks by building (and caching) the + * cuDNN execution graph for the supported configurations. The configuration parameters + * are a superset of those of ``nvte_fused_attn_fwd`` and ``nvte_fused_attn_bwd`` to + * maintain a consistent signature between graph building and runtime calls. 
* * \param[in] is_training Whether the model is in training mode. * \param[in] batch_size Batch size. @@ -205,9 +210,19 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] o_dtype The data type of Tensor O. * \param[in] scaling_mode Scaling mode of attention. * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] o_format The format of Tensor O. + * \param[in] do_format The format of Tensor dO. + * \param[in] dqkv_layout The layout of Tensors dQ, dK, dV. + * \param[in] qkv_scale_inv_format Format of the scale-inverse tensors for QKV in FP8 + * configurations; pass NVTE_QKV_Format_NOT_SET to let the + * backend infer it from ``qkv_layout`` otherwise. + * \param[in] do_scale_inv_format Format of the scale-inverse tensor for dO in FP8 backward + * configurations; pass NVTE_QKV_Format_NOT_SET to let the + * backend infer it from ``do_format`` otherwise. * \param[in] bias_type The attention bias type. * \param[in] attn_mask_type The attention mask type. * \param[in] softmax_type The attention softmax type. + * \param[in] attn_scale Scaling factor for Q * K^T. * \param[in] dropout The dropout probability. * \param[in] num_attn_heads The number of heads in Q. * \param[in] num_gqa_groups The number of heads in K, V. @@ -222,17 +237,24 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] return_max_logit Whether to produce Max along with Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] deterministic Whether determinism is required or not. - * \param[out] message Empty string on success, otherwise a diagnostic string - * describing why the configuration was rejected. + * \param[out] message Empty on success, otherwise a diagnostic string describing + * why the configuration was rejected. 
The string pointer refers to a + * per-thread buffer owned by the library and remains valid + * only until the next call to ``nvte_get_fused_attn_backend`` + * on the same thread; callers that need to retain the + * message across further calls must copy it. Pass NULL to + * skip diagnostics. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, - size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, - const char **message); + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float attn_scale, float dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool return_max_logit, bool cuda_graph, bool deterministic, const char **message); /*! \brief Compute dot product attention with separate Q, K and V. 
* diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index ac6cf8975c..735383d26d 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -13,6 +13,7 @@ import jax.numpy as jnp from transformer_engine_jax import NVTE_Bias_Type +from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_QKV_Format @@ -341,12 +342,16 @@ def is_fused_attn_kernel_available( head_dim_v, window_size: Optional[Tuple[int, int]] = None, bottom_right_diagonal: Optional[bool] = None, + return_reason: bool = False, ): """ To check whether the fused attention kernel is supported. If ``bottom_right_diagonal`` is None, it is derived from the mask type, matching the convention used everywhere else in JAX TE (see ``_FusedAttnConfig`` constructions). + + When ``return_reason`` is ``True``, returns ``(available, message)`` where ``message`` is + the diagnostic string the backend produced (empty on success). 
""" window_size_tuple = (-1, -1) if window_size is None else window_size @@ -376,7 +381,12 @@ def make_helper(attn_mask_type): bottom_right, ) - return make_helper(attn_mask_type).is_fused_attn_kernel_available() + helper = make_helper(attn_mask_type) + if return_reason: + backend, message = helper.get_fused_attn_backend() + available = backend != NVTE_Fused_Attn_Backend.NVTE_No_Backend + return available, message + return helper.is_fused_attn_kernel_available() def _obtain_batch_and_max_seqlen(qkv, qkv_layout): diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 2a533c3f3e..c00feec748 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -16,7 +16,12 @@ from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax -from transformer_engine_jax import NVTE_Fused_Attn_Backend, NVTEScalingMode +from transformer_engine_jax import ( + NVTE_Fused_Attn_Backend, + NVTE_QKV_Format, + NVTE_QKV_Layout, + NVTEScalingMode, +) from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, @@ -126,7 +131,11 @@ class FusedAttnHelper: bottom_right_diagonal: bool def is_fused_attn_kernel_available(self): - """Check if there is available fused attention kernel""" + """Check if there is available fused attention kernel. + + Use ``get_fused_attn_backend()`` directly to also get the diagnostic message + explaining why a configuration was rejected. 
+ """ backend, _ = self.get_fused_attn_backend() return backend != NVTE_Fused_Attn_Backend.NVTE_No_Backend @@ -145,6 +154,11 @@ def get_fused_attn_backend(self): q_type, NVTEScalingMode.NVTE_INVALID_SCALING, self.qkv_layout.value, + NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout.NVTE_QKV_Layout_NOT_SET, + NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET, self.attn_bias_type.value, self.attn_mask_type.value, self.softmax_type.value, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1e8d99c3d8..813ea7db4c 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -149,13 +149,20 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); // Returns (backend, message). `message` is empty on success, otherwise a diagnostic string // describing why the configuration was rejected when backend = NVTE_No_Backend. +// `o_format`, `do_format`, and `dqkv_layout` describe the output, output-gradient, and +// QKV-gradient formats/layouts the actual fwd/bwd kernels will use; pass NVTE_QKV_Format_NOT_SET +// / NVTE_QKV_Layout_NOT_SET to request that they be inferred from `qkv_layout`. +// `qkv_scale_inv_format` / `do_scale_inv_format` describe the FP8 scale-inverse layouts; pass +// NVTE_QKV_Format_NOT_SET to let the backend infer them. 
std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, - size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic); + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 5cd3265c3e..1044674856 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -13,19 +13,46 @@ namespace jax { std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, - size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t 
window_size_right, - bool bottom_right_diagonal, bool deterministic) { + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic) { + // For convenience, allow callers to pass *_NOT_SET sentinels and infer the missing values + // from `qkv_layout`; JAX's fused-attn path always uses matching output / dQKV layouts so + // this preserves the existing behavior without forcing every Python call site to compute them. + // The scale-inv formats stay as NOT_SET when the caller passes NOT_SET because cuDNN-frontend + // already infers them from the QKV layout for the recipes JAX currently exercises. + if (o_format == NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET) { + o_format = nvte_get_q_format(qkv_layout); + } + if (do_format == NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET) { + do_format = o_format; + } + if (dqkv_layout == NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET) { + dqkv_layout = qkv_layout; + } + // The pointer returned via `message` aliases a thread-local buffer in libtransformer_engine that + // is overwritten by the next nvte_get_fused_attn_backend call on this thread. We copy it into a + // std::string here so the value we return is safe to retain. + // + // NOTE: attn_scale is part of the cuDNN-frontend graph cache key (FADescriptor_v1::attnScale). + // Passing 1.0f here means the graph this probe builds will not be reused at the corresponding + // FusedAttnForwardImpl/FusedAttnBackwardImpl call (which forwards the user's actual scale). 
+ // The lost reuse is a known performance gap that will be addressed when the future + // config-struct refactor also updates this Python-facing wrapper. const char *message = nullptr; auto backend = nvte_get_fused_attn_backend( is_training, batch_size, static_cast(q_dtype), static_cast(kv_dtype), - static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, mask_type, softmax_type, - dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, - v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, - /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, &message); - return {backend, message ? std::string(message) : std::string()}; + static_cast(o_dtype), scaling_mode, qkv_layout, o_format, do_format, dqkv_layout, + qkv_scale_inv_format, do_scale_inv_format, bias_type, mask_type, softmax_type, + /*attn_scale=*/1.0f, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + bottom_right_diagonal, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, + &message); + return {backend, message != nullptr ? std::string(message) : std::string()}; } /* @@ -264,12 +291,19 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + // JAX uses the same layout for output / dQKV as for QKV, so derive the formats from qkv_layout. + // Scale-inv formats stay NOT_SET because JAX's fused-attn path here is non-FP8. 
+ const NVTE_QKV_Format probe_o_format = nvte_get_q_format(qkv_layout); auto backend = nvte_get_fused_attn_backend( is_training, input_batch, static_cast(dtype), static_cast(dtype), - static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, - softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, - /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, /*message=*/nullptr); + static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, probe_o_format, + /*do_format=*/probe_o_format, /*dqkv_layout=*/qkv_layout, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, + softmax_type, scaling_factor, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + bottom_right_diagonal, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, + /*message=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -541,12 +575,19 @@ static void FusedAttnBackwardImpl( /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); + // JAX uses the same layout for output / dQKV as for QKV, so derive the formats from qkv_layout. + // Scale-inv formats stay NOT_SET because JAX's fused-attn path here is non-FP8. 
+ const NVTE_QKV_Format probe_o_format = nvte_get_q_format(qkv_layout); auto backend = nvte_get_fused_attn_backend( is_training, input_batch, static_cast(dtype), static_cast(dtype), - static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, bias_type, mask_type, - softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - qk_head_dim, v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, - /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, /*message=*/nullptr); + static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, probe_o_format, + /*do_format=*/probe_o_format, /*dqkv_layout=*/qkv_layout, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, + softmax_type, scaling_factor, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + bottom_right_diagonal, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, + /*message=*/nullptr); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2d55abedc6..bdfec12b8b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -160,12 +160,14 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) + .value("NVTE_QKV_Layout_NOT_SET", 
NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET); pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) + .value("NVTE_QKV_Format_NOT_SET", NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET); pybind11::enum_(m, "NVTE_Softmax_Type", pybind11::module_local()) .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 184547aa92..3b8682d7bc 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -762,7 +762,7 @@ def __call__( head_dim_qk = self.head_dim head_dim_v = self.head_dim - has_fused_attn_kernel = is_fused_attn_kernel_available( + has_fused_attn_kernel, fused_attn_reject_reason = is_fused_attn_kernel_available( # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. 
not deterministic, batch_size, @@ -781,15 +781,17 @@ def __call__( head_dim_qk, head_dim_v, self.window_size, + return_reason=True, ) use_fused_attn = enable_fused_attn and has_fused_attn_kernel if enable_fused_attn and not has_fused_attn_kernel: + reason = fused_attn_reject_reason or "(no diagnostic message available)" warnings.warn( "Fused attention is not enabled because there is no available kernel.\n" "Fall back to the unfused attention.\n" - "Please try to update the cuDNN and TE to the latest version.\n" + f"Reason for this rejection is: {reason}\n" f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" f"{self.attention_dropout=}\n{self.num_attention_heads=}\n{self.window_size=}\n" f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index db55f2fbd3..cf2f297083 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -23,6 +23,7 @@ import transformer_engine as te from transformer_engine.pytorch.cpp_extensions.fused_attn import ( QKVLayout, + QKVFormat, AttnBiasType, AttnMaskType, SoftmaxType, @@ -1224,6 +1225,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt kv_type = q_type o_type = q_type scaling_mode = tex.NVTEScalingMode.NVTE_INVALID_SCALING + qkv_scale_inv_format = None + do_scale_inv_format = None if fp8 and fp8_meta["recipe"].fp8_dpa: recipe = fp8_meta["recipe"] q_type = get_fp8_te_dtype(recipe, fprop_tensor=True) @@ -1232,12 +1235,17 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if recipe.mxfp8(): scaling_mode = tex.NVTEScalingMode.NVTE_MXFP8_1D_SCALING o_type = TE_DType[torch.bfloat16] + qkv_scale_inv_format = "bhsd" + do_scale_inv_format = "bhsd" elif recipe.float8_current_scaling() and 
cs_o_in_f16: scaling_mode = tex.NVTEScalingMode.NVTE_DELAYED_TENSOR_SCALING o_type = TE_DType[torch.bfloat16] else: scaling_mode = tex.NVTEScalingMode.NVTE_DELAYED_TENSOR_SCALING o_type = q_type + o_format = q_format + do_format = o_format + dqkv_layout = qkv_layout fused_attention_backend, reject_message = tex.get_fused_attn_backend( is_training, batch_size, @@ -1246,6 +1254,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt o_type, scaling_mode, QKVLayout[qkv_layout], + QKVFormat[o_format], + QKVFormat[do_format], + QKVLayout[dqkv_layout], + QKVFormat[qkv_scale_inv_format], + QKVFormat[do_scale_inv_format], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 205e7eb834..74021b81b5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -80,6 +80,8 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T std::tuple get_fused_attn_backend( bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 41dcd3301a..4cda724a8b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -43,6 +43,8 @@ 
namespace transformer_engine::pytorch { std::tuple get_fused_attn_backend( bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, @@ -51,11 +53,12 @@ std::tuple get_fused_attn_backend( const char *message = nullptr; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, batch_size, static_cast(q_dtype), static_cast(kv_dtype), - static_cast(o_dtype), scaling_mode, qkv_layout, bias_type, attn_mask_type, - softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, + static_cast(o_dtype), scaling_mode, qkv_layout, o_format, do_format, dqkv_layout, + qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, softmax_type, + /*attn_scale=*/1.0f, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph, deterministic, &message); - return {fused_attention_backend, message ? std::string(message) : std::string()}; + return {fused_attention_backend, message != nullptr ? 
std::string(message) : std::string()}; } // helper function for S and dP quantizers From 956f159d794ae11e53a8eee5092a80a2dbf3a525 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 19:20:12 +0000 Subject: [PATCH 16/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 14 ++--- .../fused_attn_f16_arbitrary_seqlen.cu | 51 ++++++++----------- .../fused_attn_f16_arbitrary_seqlen.h | 3 +- .../common/fused_attn/fused_attn_fp8.cu | 46 ++++++++--------- .../common/fused_attn/fused_attn_fp8.h | 3 +- .../include/transformer_engine/fused_attn.h | 14 ++--- transformer_engine/jax/csrc/extensions.h | 12 ++--- .../jax/csrc/extensions/attention.cpp | 12 ++--- 8 files changed, 70 insertions(+), 85 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 47ee2ff9e5..4c64d9a3d8 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -246,13 +246,13 @@ void set_message(const char **message, std::string reason) { NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float attn_scale, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool return_max_logit, bool cuda_graph, bool deterministic, 
const char **message) { + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float attn_scale, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, + const char **message) { using namespace transformer_engine; set_message(message, ""); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 2b6ed2fca4..372bc68288 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1334,18 +1334,15 @@ void fused_attn_arbitrary_seqlen_bwd( } } -std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, bool is_training, bool return_max_logit, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, - [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, - DType qkv_dtype, [[maybe_unused]] DType o_dtype, - [[maybe_unused]] NVTEScalingMode scaling_mode, - cudnnHandle_t handle) { +std::string is_supported_f16_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool 
is_training, + bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + DType qkv_dtype, [[maybe_unused]] DType o_dtype, [[maybe_unused]] NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { const auto b = static_cast(batch); const auto h = static_cast(num_attn_heads); const auto sq = static_cast(max_seqlen_q); @@ -1399,20 +1396,16 @@ std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num } } -std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, - [[maybe_unused]] NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, DType qkv_dtype, - [[maybe_unused]] DType o_dtype, - [[maybe_unused]] NVTEScalingMode scaling_mode, - cudnnHandle_t handle) { +std::string is_supported_f16_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, + [[maybe_unused]] NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, 
int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, + [[maybe_unused]] DType o_dtype, [[maybe_unused]] NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { const auto b = static_cast(batch); const auto h = static_cast(num_attn_heads); const auto sq = static_cast(max_seqlen_q); @@ -1437,9 +1430,9 @@ std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num fused_attn::fused_attn_arbitrary_seqlen_bwd_impl( b, h, static_cast(num_gqa_groups), sq, skv, static_cast(head_dim_qk), static_cast(head_dim_v), max_b, max_t_q, max_t_kv, bias_b, bias_h, bias_sq, - bias_skv, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, + bias_skv, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, /*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr, /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, /*devPtrdO=*/nullptr, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 078b6c700d..9d6b57d0a0 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -76,8 +76,7 @@ std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, - DType 
o_dtype, NVTEScalingMode scaling_mode, - cudnnHandle_t handle); + DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fb0790f230..352e47fbfd 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1325,17 +1325,14 @@ void fused_attn_fp8_bwd( } } -std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, bool is_training, - [[maybe_unused]] bool return_max_logit, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, - DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, - cudnnHandle_t handle) { +std::string is_supported_fp8_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, + [[maybe_unused]] bool return_max_logit, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle) { size_t workspace_size = 0; try { fused_attn::fused_attn_fp8_fwd_impl( @@ -1363,18 +1360,15 @@ std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num } } -std::string is_supported_fp8_bwd(size_t batch, size_t 
num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, - DType o_dtype, NVTEScalingMode scaling_mode, - cudnnHandle_t handle) { +std::string is_supported_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle) { const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); const cudnn_frontend::DataType_t do_t = o_t; @@ -1385,9 +1379,9 @@ std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num static_cast(batch), static_cast(num_attn_heads), static_cast(num_gqa_groups), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), static_cast(head_dim_qk), - static_cast(head_dim_v), attn_scale, p_dropout, qkv_layout, o_format, - do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, deterministic, + 
static_cast(head_dim_v), attn_scale, p_dropout, qkv_layout, o_format, do_format, + dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, /*devPtrdO=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 078737b99e..8dfdb8c412 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -69,6 +69,5 @@ std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, - DType o_dtype, NVTEScalingMode scaling_mode, - cudnnHandle_t handle); + DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 0d20712207..8e9864f916 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -248,13 +248,13 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type 
attn_mask_type, NVTE_Softmax_Type softmax_type, - float attn_scale, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool return_max_logit, bool cuda_graph, bool deterministic, const char **message); + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float attn_scale, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, + const char **message); /*! \brief Compute dot product attention with separate Q, K and V. 
* diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 813ea7db4c..c543ae6019 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -157,12 +157,12 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic); + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, + size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, + size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 1044674856..498d614812 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -14,12 +14,12 @@ namespace jax { std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, 
DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic) { + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, + size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, + size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic) { // For convenience, allow callers to pass *_NOT_SET sentinels and infer the missing values // from `qkv_layout`; JAX's fused-attn path always uses matching output / dQKV layouts so // this preserves the existing behavior without forcing every Python call site to compute them. 
From b21f6065a1cd900b36cfd2cc8879074718646c76 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 8 May 2026 12:33:51 -0700 Subject: [PATCH 17/23] minor tweaks for docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/jax/attention.py | 5 +--- transformer_engine/jax/csrc/extensions.h | 7 ----- .../jax/csrc/extensions/attention.cpp | 28 ++++--------------- transformer_engine/jax/flax/transformer.py | 5 ++-- 4 files changed, 9 insertions(+), 36 deletions(-) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 735383d26d..e4fce42ce7 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -347,11 +347,8 @@ def is_fused_attn_kernel_available( """ To check whether the fused attention kernel is supported. - If ``bottom_right_diagonal`` is None, it is derived from the mask type, matching the - convention used everywhere else in JAX TE (see ``_FusedAttnConfig`` constructions). - When ``return_reason`` is ``True``, returns ``(available, message)`` where ``message`` is - the diagnostic string the backend produced (empty on success). + the diagnostic string for the reason why the fused attention kernel is not supported (empty on success). """ window_size_tuple = (-1, -1) if window_size is None else window_size diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c543ae6019..cf455357c6 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -147,13 +147,6 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); -// Returns (backend, message). `message` is empty on success, otherwise a diagnostic string -// describing why the configuration was rejected when backend = NVTE_No_Backend. 
-// `o_format`, `do_format`, and `dqkv_layout` describe the output, output-gradient, and -// QKV-gradient formats/layouts the actual fwd/bwd kernels will use; pass NVTE_QKV_Format_NOT_SET -// / NVTE_QKV_Layout_NOT_SET to request that they be inferred from `qkv_layout`. -// `qkv_scale_inv_format` / `do_scale_inv_format` describe the FP8 scale-inverse layouts; pass -// NVTE_QKV_Format_NOT_SET to let the backend infer them. std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 498d614812..e37bd4442c 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -14,17 +14,12 @@ namespace jax { std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, - size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, - size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic) { - // For convenience, allow callers to pass *_NOT_SET sentinels and infer the missing values - // from `qkv_layout`; JAX's fused-attn path always uses matching output / dQKV layouts so - // this preserves the existing behavior without forcing every Python call site to compute them. 
- // The scale-inv formats stay as NOT_SET when the caller passes NOT_SET because cuDNN-frontend - // already infers them from the QKV layout for the recipes JAX currently exercises. + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic) { if (o_format == NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET) { o_format = nvte_get_q_format(qkv_layout); } @@ -34,15 +29,6 @@ std::tuple GetFusedAttnBackend( if (dqkv_layout == NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET) { dqkv_layout = qkv_layout; } - // The pointer returned via `message` aliases a thread-local buffer in libtransformer_engine that - // is overwritten by the next nvte_get_fused_attn_backend call on this thread. We copy it into a - // std::string here so the value we return is safe to retain. - // - // NOTE: attn_scale is part of the cuDNN-frontend graph cache key (FADescriptor_v1::attnScale). - // Passing 1.0f here means the graph this probe builds will not be reused at the corresponding - // FusedAttnForwardImpl/FusedAttnBackwardImpl call (which forwards the user's actual scale). - // The lost reuse is a known performance gap that will be addressed when the future - // config-struct refactor also updates this Python-facing wrapper. 
const char *message = nullptr; auto backend = nvte_get_fused_attn_backend( is_training, batch_size, static_cast(q_dtype), static_cast(kv_dtype), @@ -291,8 +277,6 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - // JAX uses the same layout for output / dQKV as for QKV, so derive the formats from qkv_layout. - // Scale-inv formats stay NOT_SET because JAX's fused-attn path here is non-FP8. const NVTE_QKV_Format probe_o_format = nvte_get_q_format(qkv_layout); auto backend = nvte_get_fused_attn_backend( is_training, input_batch, static_cast(dtype), static_cast(dtype), diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 3b8682d7bc..f5ca3ff04d 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -789,12 +789,11 @@ def __call__( if enable_fused_attn and not has_fused_attn_kernel: reason = fused_attn_reject_reason or "(no diagnostic message available)" warnings.warn( - "Fused attention is not enabled because there is no available kernel.\n" - "Fall back to the unfused attention.\n" - f"Reason for this rejection is: {reason}\n" + "Falling back to the unfused attention backend as fused attention does not support:\n" f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" f"{self.attention_dropout=}\n{self.num_attention_heads=}\n{self.window_size=}\n" f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n" + f"Reason for this rejection: {reason}\n" ) dropout_rng = None From 34219205eab29cfc68c0afcda53b44309edd53c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 19:36:47 +0000 Subject: [PATCH 18/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/csrc/extensions/attention.cpp | 12 
++++++------ transformer_engine/jax/flax/transformer.py | 8 +++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index e37bd4442c..9c6f2483cd 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -14,12 +14,12 @@ namespace jax { std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic) { + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, + size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, + size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic) { if (o_format == NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET) { o_format = nvte_get_q_format(qkv_layout); } diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index f5ca3ff04d..35a48442d2 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -789,11 +789,9 @@ def __call__( if enable_fused_attn and not 
has_fused_attn_kernel: reason = fused_attn_reject_reason or "(no diagnostic message available)" warnings.warn( - "Falling back to the unfused attention backend as fused attention does not support:\n" - f"{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" - f"{self.attention_dropout=}\n{self.num_attention_heads=}\n{self.window_size=}\n" - f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n" - f"Reason for this rejection: {reason}\n" + "Falling back to the unfused attention backend as fused attention does not" + f" support:\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n{self.attention_dropout=}\n{self.num_attention_heads=}\n{self.window_size=}\n{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\nReason" + f" for this rejection: {reason}\n" ) dropout_rng = None From 7956b4339ccf44f4d4792dd02a378c3520ef7b8a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 8 May 2026 14:41:34 -0700 Subject: [PATCH 19/23] replace with nvte_get_fused_attn_backend_v2 and add NVTEFusedAttnConfig Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 231 ++++++++++++------ .../fused_attn_f16_arbitrary_seqlen.cu | 70 ++++-- .../fused_attn_f16_arbitrary_seqlen.h | 23 +- .../common/fused_attn/fused_attn_fp8.cu | 68 ++++-- .../common/fused_attn/fused_attn_fp8.h | 23 +- .../include/transformer_engine/fused_attn.h | 185 +++++++++----- .../jax/cpp_extensions/attention.py | 3 + transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/attention.cpp | 92 ++++--- .../dot_product_attention.py | 2 + .../attention/dot_product_attention/utils.py | 6 + transformer_engine/pytorch/csrc/extensions.h | 8 +- .../pytorch/csrc/extensions/attention.cpp | 48 +++- 13 files changed, 496 insertions(+), 269 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp 
b/transformer_engine/common/fused_attn/fused_attn.cpp index 4c64d9a3d8..6670fd59ed 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -228,7 +228,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { namespace { // per-thread storage for the diagnostic string -// re-used (cleared + re-populated) on every call to nvte_get_fused_attn_backend on this thread +// re-used (cleared + re-populated) on every call to nvte_get_fused_attn_backend_v2 on this thread thread_local std::string fused_attn_backend_message_buffer; // Stash `reason` in the thread-local buffer and, if the caller asked for a diagnostic, @@ -243,30 +243,26 @@ void set_message(const char **message, std::string reason) { } // namespace // select a backend for fused attention -NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, float attn_scale, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, - const char **message) { +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig *cfg, + const char **message) { using namespace transformer_engine; set_message(message, ""); - NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); + NVTE_CHECK(cfg != nullptr, "NVTEFusedAttnConfig pointer must not be NULL."); + NVTE_CHECK(cfg->struct_size == 
sizeof(NVTEFusedAttnConfig), + "NVTEFusedAttnConfig::struct_size must equal sizeof(NVTEFusedAttnConfig); " + "did you forget NVTE_FUSED_ATTN_CONFIG_INIT?"); cudnnHandle_t handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(cfg->qkv_layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(cfg->qkv_layout); const auto cudnn_runtime_version = cudnnGetVersion(); // THD + 64-bit ragged offsets require cuDNN >= 9.5 const bool requires_64bit_ragged_offset = - (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( - layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); + (qkv_format == NVTE_THD && + fused_attn::get_ragged_offset_dtype(layout_group, cfg->num_attn_heads, cfg->num_gqa_groups, + cfg->max_seqlen_q, cfg->max_seqlen_kv, cfg->head_dim_qk, + cfg->head_dim_v) == DType::kInt64); if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { set_message(message, "Configuration requires 64-bit ragged offsets, which require " @@ -276,21 +272,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // THD requires padding-style mask if (qkv_format == NVTE_QKV_Format::NVTE_THD && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + cfg->attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + cfg->attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + cfg->attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { set_message(message, "THD format requires PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT mask."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - const bool 
is_fp8 = - (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2); - const bool is_f16_or_bf16 = - (q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16); + const bool is_fp8 = (cfg->qkv_dtype == NVTEDType::kNVTEFloat8E4M3 || + cfg->qkv_dtype == NVTEDType::kNVTEFloat8E5M2); + const bool is_f16_or_bf16 = (cfg->qkv_dtype == NVTEDType::kNVTEFloat16 || + cfg->qkv_dtype == NVTEDType::kNVTEBFloat16); if (is_fp8) { - if (return_max_logit) { + if (cfg->return_max_logit) { set_message(message, "FP8 fused attention does not support return_max_logit=True."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -300,24 +296,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( std::to_string(static_cast(qkv_format)) + "."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - const DType qkv_t = static_cast(q_dtype); - const DType o_t = static_cast(o_dtype); - std::string fwd_reason = is_supported_fp8_fwd( - batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, o_format, - qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, handle); + std::string fwd_reason = is_supported_fp8_fwd(cfg, handle); if (!fwd_reason.empty()) { set_message(message, std::move(fwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (is_training) { - std::string bwd_reason = is_supported_fp8_bwd( - batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, attn_scale, dropout, qkv_layout, o_format, do_format, dqkv_layout, - qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, - scaling_mode, handle); + if (cfg->is_training) { + std::string bwd_reason = 
is_supported_fp8_bwd(cfg, handle); if (!bwd_reason.empty()) { set_message(message, std::move(bwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; @@ -327,33 +312,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if (is_f16_or_bf16) { - if (cudnn_runtime_version <= 91500 && is_training && + if (cudnn_runtime_version <= 91500 && cfg->is_training && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - (max_seqlen_kv % 128 != 0) && cuda_graph && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + (cfg->max_seqlen_kv % 128 != 0) && cfg->cuda_graph && + cfg->attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + cfg->attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + cfg->attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. 
Please upgrade cuDNN."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - const DType qkv_t = static_cast(q_dtype); - const DType o_t = static_cast(o_dtype); - std::string fwd_reason = is_supported_f16_fwd( - batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, o_format, - qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, bottom_right_diagonal, qkv_t, o_t, scaling_mode, handle); + std::string fwd_reason = is_supported_f16_fwd(cfg, handle); if (!fwd_reason.empty()) { set_message(message, std::move(fwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (is_training) { - std::string bwd_reason = is_supported_f16_bwd( - batch_size, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, - head_dim_v, attn_scale, dropout, qkv_layout, o_format, do_format, dqkv_layout, - qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, qkv_t, o_t, - scaling_mode, handle); + if (cfg->is_training) { + std::string bwd_reason = is_supported_f16_bwd(cfg, handle); if (!bwd_reason.empty()) { set_message(message, std::move(bwd_reason)); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; @@ -362,10 +336,48 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - set_message(message, "Unsupported QKV dtype qkv_dtype=" + std::to_string(q_dtype) + " ."); + set_message(message, + "Unsupported QKV dtype qkv_dtype=" + std::to_string(cfg->qkv_dtype) + " ."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } +// Deprecated: thin wrapper preserving the historical narrow signature. 
New callers should +// construct an NVTEFusedAttnConfig and call nvte_get_fused_attn_backend_v2 directly to access +// the additional fields (attn_scale, format/layout fields, scaling_mode, paged-KV/bias shape, etc.) +// that this wrapper cannot express. +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + cfg.qkv_layout = qkv_layout; + cfg.dqkv_layout = qkv_layout; // legacy: gradient layout matches input layout + cfg.bias_type = bias_type; + cfg.attn_mask_type = attn_mask_type; + cfg.softmax_type = softmax_type; + cfg.attn_scale = 1.0f; // legacy default; matches the value pre-PR probes hardcoded + cfg.dropout = dropout; + cfg.max_seqlen_q = max_seqlen_q; + cfg.max_seqlen_kv = max_seqlen_kv; + cfg.window_size_left = window_size_left; + cfg.window_size_right = window_size_right; + cfg.cuda_graph = cuda_graph; + NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); + cfg.qkv_dtype = q_dtype; + cfg.o_dtype = q_dtype; // legacy: O dtype matches Q dtype + cfg.batch_size = 1; // legacy: pre-PR probes assumed batch=1 + cfg.num_attn_heads = num_attn_heads; + cfg.num_gqa_groups = num_gqa_groups; + cfg.head_dim_qk = head_dim_qk; + cfg.head_dim_v = head_dim_v; + cfg.is_training = is_training; + cfg.return_max_logit = return_max_logit; + cfg.deterministic = deterministic; + return nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr); +} + // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, 
const NVTETensor V, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, @@ -448,16 +460,61 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + NVTE_CHECK(Q_type == KV_type, "Q and KV must have the same data type."); const NVTEDType O_type = static_cast(output_O->data.dtype); const NVTEScalingMode scaling_mode = input_Q->scaling_mode; - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, o_format, - /*do_format=*/o_format, /*dqkv_layout=*/qkv_layout, qkv_scale_inv_format, - /*do_scale_inv_format=*/qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, - attn_scale, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, - window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph, - /*deterministic=*/false, /*message=*/nullptr); + size_t bias_b = 0, bias_h = 0, bias_sq = 0, bias_skv = 0; + if (input_Bias->data.dptr != nullptr && input_Bias->data.shape.size() >= 4) { + bias_b = input_Bias->data.shape[0]; + bias_h = input_Bias->data.shape[1]; + bias_sq = input_Bias->data.shape[2]; + bias_skv = input_Bias->data.shape[3]; + } + + NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + cfg.qkv_layout = qkv_layout; + cfg.o_format = o_format; + cfg.do_format = o_format; // fwd path: same format used for dO if/when probed for bwd + cfg.dqkv_layout = qkv_layout; // fwd path: same layout used for dQKV if/when probed for bwd + cfg.qkv_scale_inv_format = qkv_scale_inv_format; + cfg.do_scale_inv_format = qkv_scale_inv_format; // fwd path: mirror QKV + cfg.bias_type = bias_type; + cfg.attn_mask_type = attn_mask_type; + cfg.softmax_type = softmax_type; + cfg.scaling_mode = scaling_mode; + cfg.attn_scale = 
attn_scale; + cfg.dropout = dropout; + cfg.max_seqlen_q = max_seqlen_q; + cfg.max_seqlen_kv = max_seqlen_kv; + cfg.window_size_left = window_size_left; + cfg.window_size_right = window_size_right; + cfg.bottom_right_diagonal = bottom_right_diagonal; + cfg.cuda_graph = cuda_graph; + cfg.qkv_dtype = Q_type; + cfg.o_dtype = O_type; + cfg.do_dtype = O_type; // fwd path: dO assumed to share dtype with O + cfg.dqkv_dtype = Q_type; // fwd path: dQKV assumed to share dtype with QKV + cfg.batch_size = b; + cfg.num_attn_heads = h_q; + cfg.num_gqa_groups = h_kv; + cfg.head_dim_qk = d_qk; + cfg.head_dim_v = d_v; + cfg.num_pages_k = static_cast(num_pages_k); + cfg.num_pages_v = static_cast(num_pages_v); + cfg.page_size_k = static_cast(page_size_k); + cfg.page_size_v = static_cast(page_size_v); + cfg.max_pages_per_seq_k = static_cast(max_pages_per_seq_k); + cfg.max_pages_per_seq_v = static_cast(max_pages_per_seq_v); + cfg.bias_batch_size = bias_b; + cfg.bias_num_heads = bias_h; + cfg.bias_seqlen_q = bias_sq; + cfg.bias_seqlen_kv = bias_skv; + cfg.is_training = is_training; + cfg.return_max_logit = return_max_logit; + cfg.deterministic = false; + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { fused_attn_arbitrary_seqlen_fwd( @@ -534,15 +591,45 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + NVTE_CHECK(Q_type == KV_type, "Q and KV must have the same data type."); const NVTEDType O_type = static_cast(input_O->data.dtype); + const NVTEDType dO_type = static_cast(input_dO->data.dtype); + const NVTEDType dQKV_type = static_cast(output_dQ->data.dtype); const NVTEScalingMode scaling_mode = input_Q->scaling_mode; - 
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - /*is_training=*/true, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, o_format, - do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, - softmax_type, attn_scale, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, - window_size_left, window_size_right, bottom_right_diagonal, /*return_max_logit=*/false, - cuda_graph, deterministic, /*message=*/nullptr); + NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + cfg.qkv_layout = qkv_layout; + cfg.o_format = o_format; + cfg.do_format = do_format; + cfg.dqkv_layout = dqkv_layout; + cfg.qkv_scale_inv_format = qkv_scale_inv_format; + cfg.do_scale_inv_format = do_scale_inv_format; + cfg.bias_type = bias_type; + cfg.attn_mask_type = attn_mask_type; + cfg.softmax_type = softmax_type; + cfg.scaling_mode = scaling_mode; + cfg.attn_scale = attn_scale; + cfg.dropout = dropout; + cfg.max_seqlen_q = max_seqlen_q; + cfg.max_seqlen_kv = max_seqlen_kv; + cfg.window_size_left = window_size_left; + cfg.window_size_right = window_size_right; + cfg.bottom_right_diagonal = bottom_right_diagonal; + cfg.cuda_graph = cuda_graph; + cfg.qkv_dtype = Q_type; + cfg.o_dtype = O_type; + cfg.do_dtype = dO_type; + cfg.dqkv_dtype = dQKV_type; + cfg.batch_size = b; + cfg.num_attn_heads = h_q; + cfg.num_gqa_groups = h_kv; + cfg.head_dim_qk = d_qk; + cfg.head_dim_v = d_v; + cfg.is_training = true; + cfg.return_max_logit = false; + cfg.deterministic = deterministic; + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { size_t i = 0; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 372bc68288..9cdee256ed 100644 --- 
a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1334,19 +1334,27 @@ void fused_attn_arbitrary_seqlen_bwd( } } -std::string is_supported_f16_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, - bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - DType qkv_dtype, [[maybe_unused]] DType o_dtype, [[maybe_unused]] NVTEScalingMode scaling_mode, - cudnnHandle_t handle) { - const auto b = static_cast(batch); - const auto h = static_cast(num_attn_heads); - const auto sq = static_cast(max_seqlen_q); - const auto skv = static_cast(max_seqlen_kv); +std::string is_supported_f16_fwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle) { + const size_t num_gqa_groups = cfg->num_gqa_groups; + const size_t head_dim_qk = cfg->head_dim_qk; + const size_t head_dim_v = cfg->head_dim_v; + const bool is_training = cfg->is_training; + const bool return_max_logit = cfg->return_max_logit; + const float attn_scale = cfg->attn_scale; + const float p_dropout = cfg->dropout; + const NVTE_QKV_Layout qkv_layout = cfg->qkv_layout; + const NVTE_QKV_Format o_format = cfg->o_format; + const NVTE_Bias_Type bias_type = cfg->bias_type; + const NVTE_Mask_Type mask_type = cfg->attn_mask_type; + const NVTE_Softmax_Type softmax_type = cfg->softmax_type; + const int64_t window_size_left = cfg->window_size_left; + const int64_t window_size_right = cfg->window_size_right; + const bool bottom_right_diagonal = cfg->bottom_right_diagonal; + const DType qkv_dtype = static_cast(cfg->qkv_dtype); + const auto b 
= static_cast(cfg->batch_size); + const auto h = static_cast(cfg->num_attn_heads); + const auto sq = static_cast(cfg->max_seqlen_q); + const auto skv = static_cast(cfg->max_seqlen_kv); const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); @@ -1396,20 +1404,28 @@ std::string is_supported_f16_fwd( } } -std::string is_supported_f16_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, - NVTE_QKV_Layout dqkv_layout, [[maybe_unused]] NVTE_QKV_Format qkv_scale_inv_format, - [[maybe_unused]] NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, - [[maybe_unused]] DType o_dtype, [[maybe_unused]] NVTEScalingMode scaling_mode, - cudnnHandle_t handle) { - const auto b = static_cast(batch); - const auto h = static_cast(num_attn_heads); - const auto sq = static_cast(max_seqlen_q); - const auto skv = static_cast(max_seqlen_kv); +std::string is_supported_f16_bwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle) { + const size_t num_gqa_groups = cfg->num_gqa_groups; + const size_t head_dim_qk = cfg->head_dim_qk; + const size_t head_dim_v = cfg->head_dim_v; + const float attn_scale = cfg->attn_scale; + const float p_dropout = cfg->dropout; + const NVTE_QKV_Layout qkv_layout = cfg->qkv_layout; + const NVTE_QKV_Format o_format = cfg->o_format; + const NVTE_QKV_Format do_format = cfg->do_format; + const NVTE_QKV_Layout dqkv_layout = cfg->dqkv_layout; + const NVTE_Bias_Type bias_type = cfg->bias_type; + const NVTE_Mask_Type mask_type = cfg->attn_mask_type; + const NVTE_Softmax_Type softmax_type = 
cfg->softmax_type; + const int64_t window_size_left = cfg->window_size_left; + const int64_t window_size_right = cfg->window_size_right; + const bool bottom_right_diagonal = cfg->bottom_right_diagonal; + const bool deterministic = cfg->deterministic; + const DType qkv_dtype = static_cast(cfg->qkv_dtype); + const auto b = static_cast(cfg->batch_size); + const auto h = static_cast(cfg->num_attn_heads); + const auto sq = static_cast(cfg->max_seqlen_q); + const auto skv = static_cast(cfg->max_seqlen_kv); const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 9d6b57d0a0..5d27e82278 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -52,31 +52,12 @@ void fused_attn_arbitrary_seqlen_bwd( // check if a given configuration is supported for F16/BF16 forward; // if it is, cache the graph built for this config, and return an empty string; // if not, return a diagnostic message in the form of a string. 
-std::string is_supported_f16_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, bool is_training, bool return_max_logit, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, - DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, - cudnnHandle_t handle); +std::string is_supported_f16_fwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle); // check if a given configuration is supported for F16/BF16 backward; // if it is, cache the graph built for this config, and return an empty string; // if not, return a diagnostic message in the form of a string. -std::string is_supported_f16_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, - DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); +std::string is_supported_f16_bwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 352e47fbfd..3fe5b7fb10 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ 
b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1325,14 +1325,30 @@ void fused_attn_fp8_bwd( } } -std::string is_supported_fp8_fwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, - [[maybe_unused]] bool return_max_logit, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle) { +std::string is_supported_fp8_fwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle) { + const size_t batch = cfg->batch_size; + const size_t num_attn_heads = cfg->num_attn_heads; + const size_t num_gqa_groups = cfg->num_gqa_groups; + const size_t max_seqlen_q = cfg->max_seqlen_q; + const size_t max_seqlen_kv = cfg->max_seqlen_kv; + const size_t head_dim_qk = cfg->head_dim_qk; + const size_t head_dim_v = cfg->head_dim_v; + const bool is_training = cfg->is_training; + const float attn_scale = cfg->attn_scale; + const float p_dropout = cfg->dropout; + const NVTE_QKV_Layout qkv_layout = cfg->qkv_layout; + const NVTE_QKV_Format o_format = cfg->o_format; + const NVTE_QKV_Format qkv_scale_inv_format = cfg->qkv_scale_inv_format; + const NVTE_Bias_Type bias_type = cfg->bias_type; + const NVTE_Mask_Type mask_type = cfg->attn_mask_type; + const NVTE_Softmax_Type softmax_type = cfg->softmax_type; + const int64_t window_size_left = cfg->window_size_left; + const int64_t window_size_right = cfg->window_size_right; + const bool bottom_right_diagonal = cfg->bottom_right_diagonal; + const DType qkv_dtype = static_cast(cfg->qkv_dtype); + const DType o_dtype = static_cast(cfg->o_dtype); + const NVTEScalingMode scaling_mode = cfg->scaling_mode; + size_t 
workspace_size = 0; try { fused_attn::fused_attn_fp8_fwd_impl( @@ -1360,15 +1376,33 @@ std::string is_supported_fp8_fwd( } } -std::string is_supported_fp8_bwd( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, - NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, cudnnHandle_t handle) { +std::string is_supported_fp8_bwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle) { + const size_t batch = cfg->batch_size; + const size_t num_attn_heads = cfg->num_attn_heads; + const size_t num_gqa_groups = cfg->num_gqa_groups; + const size_t max_seqlen_q = cfg->max_seqlen_q; + const size_t max_seqlen_kv = cfg->max_seqlen_kv; + const size_t head_dim_qk = cfg->head_dim_qk; + const size_t head_dim_v = cfg->head_dim_v; + const float attn_scale = cfg->attn_scale; + const float p_dropout = cfg->dropout; + const NVTE_QKV_Layout qkv_layout = cfg->qkv_layout; + const NVTE_QKV_Format o_format = cfg->o_format; + const NVTE_QKV_Format do_format = cfg->do_format; + const NVTE_QKV_Layout dqkv_layout = cfg->dqkv_layout; + const NVTE_QKV_Format qkv_scale_inv_format = cfg->qkv_scale_inv_format; + const NVTE_QKV_Format do_scale_inv_format = cfg->do_scale_inv_format; + const NVTE_Bias_Type bias_type = cfg->bias_type; + const NVTE_Mask_Type mask_type = cfg->attn_mask_type; + const NVTE_Softmax_Type softmax_type = cfg->softmax_type; + const int64_t window_size_left = cfg->window_size_left; + const int64_t window_size_right = cfg->window_size_right; + const bool bottom_right_diagonal 
= cfg->bottom_right_diagonal;
+  const bool deterministic = cfg->deterministic;
+  const DType qkv_dtype = static_cast<DType>(cfg->qkv_dtype);
+  const DType o_dtype = static_cast<DType>(cfg->o_dtype);
+  const NVTEScalingMode scaling_mode = cfg->scaling_mode;
+
   const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
   const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype);
   const cudnn_frontend::DataType_t do_t = o_t;
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h
index 8dfdb8c412..fc60987cf3 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.h
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h
@@ -45,29 +45,10 @@ void fused_attn_fp8_bwd(
 // check if a given configuration is supported for FP8 forward;
 // if it is, cache the graph built for this config, and return an empty string;
 // if not, return a diagnostic message in the form of a string.
-std::string is_supported_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
-                                 size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
-                                 size_t head_dim_v, bool is_training, bool return_max_logit,
-                                 float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                 NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format,
-                                 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                 NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-                                 int64_t window_size_right, bool bottom_right_diagonal,
-                                 DType qkv_dtype, DType o_dtype, NVTEScalingMode scaling_mode,
-                                 cudnnHandle_t handle);
+std::string is_supported_fp8_fwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle);
 
 // check if a given configuration is supported for FP8 backward;
 // if it is, cache the graph built for this config, and return an empty string;
 // if not, return a diagnostic message in the form of a string.
-std::string is_supported_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic, DType qkv_dtype, - DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); +std::string is_supported_fp8_bwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 8e9864f916..df15148350 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -196,65 +196,140 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \struct NVTEFusedAttnConfig + * \brief Attention configuration. + * + * Versioning rules: + * - ``struct_size`` MUST be set to ``sizeof(NVTEFusedAttnConfig)`` by the + * caller (use ``NVTE_FUSED_ATTN_CONFIG_INIT``). + * - New fields may only be appended at the end; existing fields are never + * reordered, removed, or resized. The library reads only fields that are + * in range according to ``struct_size`` and uses safe defaults otherwise. + */ +typedef struct NVTEFusedAttnConfig { + size_t struct_size; /*!< MUST equal sizeof(NVTEFusedAttnConfig). */ + uint32_t reserved0; /*!< Padding for layout stability; set to 0. */ + uint32_t reserved1; /*!< Padding for layout stability; set to 0. 
*/ + + NVTE_QKV_Layout qkv_layout; /*!< QKV tensors' layout. */ + NVTE_QKV_Format o_format; /*!< Output O tensor format. */ + NVTE_QKV_Format do_format; /*!< Output-grad dO tensor format (bwd). */ + NVTE_QKV_Layout dqkv_layout; /*!< Gradient dQKV tensor layout (bwd). */ + NVTE_QKV_Format qkv_scale_inv_format; /*!< QKV scale_inv tensor format (FP8). */ + NVTE_QKV_Format do_scale_inv_format; /*!< dO scale_inv tensor format (FP8 bwd). */ + NVTE_Bias_Type bias_type; /*!< Attention bias type. */ + NVTE_Mask_Type attn_mask_type; /*!< Attention mask type. */ + NVTE_Softmax_Type softmax_type; /*!< Attention softmax type. */ + NVTEScalingMode scaling_mode; /*!< Scaling mode (e.g. delayed, MXFP8). */ + float attn_scale; /*!< Pre-softmax attention scale factor. */ + float dropout; /*!< Dropout probability. */ + size_t max_seqlen_q; /*!< Max sequence length for Q. */ + size_t max_seqlen_kv; /*!< Max sequence length for K, V. */ + int64_t window_size_left; /*!< Sliding window size (left half); -1 = unlimited. */ + int64_t window_size_right; /*!< Sliding window size (right half); -1 = unlimited. */ + bool bottom_right_diagonal; /*!< Whether causal mask aligns to the bottom-right diagonal. */ + bool cuda_graph; /*!< Whether CUDA graph capture is enabled. */ + + NVTEDType qkv_dtype; /*!< Data type of Tensors Q, K, V. Q and K/V must share a dtype. */ + NVTEDType o_dtype; /*!< Data type of Tensor O. */ + NVTEDType do_dtype; /*!< Data type of Tensor dO (bwd). */ + NVTEDType dqkv_dtype; /*!< Data type of Tensors dQ, dK, dV (bwd). */ + size_t batch_size; /*!< Batch size. */ + size_t num_attn_heads; /*!< Number of heads in Q. */ + size_t num_gqa_groups; /*!< Number of heads in K, V. */ + size_t head_dim_qk; /*!< Head dimension of Q, K. */ + size_t head_dim_v; /*!< Head dimension of V. */ + + size_t num_pages_k; /*!< Total number of K cache pages. */ + size_t num_pages_v; /*!< Total number of V cache pages. */ + size_t page_size_k; /*!< Tokens per K cache page. 
*/ + size_t page_size_v; /*!< Tokens per V cache page. */ + size_t max_pages_per_seq_k; /*!< Max K pages per sequence in the batch. */ + size_t max_pages_per_seq_v; /*!< Max V pages per sequence in the batch. */ + + size_t bias_batch_size; /*!< Bias broadcast dim for batch. */ + size_t bias_num_heads; /*!< Bias broadcast dim for heads. */ + size_t bias_seqlen_q; /*!< Bias broadcast dim for Q sequence length. */ + size_t bias_seqlen_kv; /*!< Bias broadcast dim for K/V sequence length. */ + + bool is_training; /*!< Whether the model is in training mode. */ + bool return_max_logit; /*!< Whether to produce Max along with Stats (fwd-only). */ + bool deterministic; /*!< Whether determinism is required (bwd-only). */ +} NVTEFusedAttnConfig; + +/*! \brief Default-initialize an ``NVTEFusedAttnConfig``. + * + * Sets ``struct_size`` and the categorical fields (layouts, formats, masks, + * window sizes, scaling mode) to safe NOT_SET / no-op defaults. Numeric and + * tensor-derived fields, paged-KV shape, bias broadcast shape, and direction + * flags all default to zero/false; callers must set the fields relevant to + * their query. + */ +#define NVTE_FUSED_ATTN_CONFIG_INIT \ + { \ + .struct_size = sizeof(NVTEFusedAttnConfig), \ + .qkv_layout = NVTE_QKV_Layout_NOT_SET, .o_format = NVTE_QKV_Format_NOT_SET, \ + .do_format = NVTE_QKV_Format_NOT_SET, .dqkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .do_scale_inv_format = NVTE_QKV_Format_NOT_SET, .bias_type = NVTE_NO_BIAS, \ + .attn_mask_type = NVTE_NO_MASK, .softmax_type = NVTE_VANILLA_SOFTMAX, \ + .scaling_mode = NVTE_DELAYED_TENSOR_SCALING, .window_size_left = -1, .window_size_right = -1, \ + } + +/*! \brief Get fused attention backend based on input parameters. + * + * This call exercises cudnn-frontend's support checks by building (and caching) + * the cuDNN execution graph for the supported configurations. 
The configuration + * parameters are a superset of those of ``nvte_fused_attn_fwd`` and + * ``nvte_fused_attn_bwd`` to maintain a consistent signature between graph + * building and runtime calls. + * + * \param[in] cfg Attention configuration. Must be initialized + * with ``NVTE_FUSED_ATTN_CONFIG_INIT`` and have + * ``cfg->struct_size`` set to ``sizeof(NVTEFusedAttnConfig)``. + * \param[out] message Empty on success, otherwise a diagnostic string describing + * why the configuration was rejected. The string pointer + * refers to a per-thread buffer owned by the library and + * remains valid only until the next call to + * ``nvte_get_fused_attn_backend_v2`` on the same thread; + * callers that need to retain the message across further + * calls must copy it. Pass NULL to skip diagnostics. + * + * \return Backend able to execute this configuration, or ``NVTE_No_Backend`` if none. + */ +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig *cfg, + const char **message); + /*! \brief Get fused attention backend based on input parameters. * - * This call exercises cudnn-frontend's support checks by building (and caching) the - * cuDNN execution graph for the supported configurations. The configuration parameters - * are a superset of those of ``nvte_fused_attn_fwd`` and ``nvte_fused_attn_bwd`` to - * maintain a consistent signature between graph building and runtime calls. - * - * \param[in] is_training Whether the model is in training mode. - * \param[in] batch_size Batch size. - * \param[in] q_dtype The data type of Tensor Q. - * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] o_dtype The data type of Tensor O. - * \param[in] scaling_mode Scaling mode of attention. - * \param[in] qkv_layout The layout of Tensors Q, K, V. - * \param[in] o_format The format of Tensor O. - * \param[in] do_format The format of Tensor dO. - * \param[in] dqkv_layout The layout of Tensors dQ, dK, dV. 
- * \param[in] qkv_scale_inv_format Format of the scale-inverse tensors for QKV in FP8 - * configurations; pass NVTE_QKV_Format_NOT_SET to let the - * backend infer it from ``qkv_layout`` otherwise. - * \param[in] do_scale_inv_format Format of the scale-inverse tensor for dO in FP8 backward - * configurations; pass NVTE_QKV_Format_NOT_SET to let the - * backend infer it from ``do_format`` otherwise. - * \param[in] bias_type The attention bias type. - * \param[in] attn_mask_type The attention mask type. - * \param[in] softmax_type The attention softmax type. - * \param[in] attn_scale Scaling factor for Q * K^T. - * \param[in] dropout The dropout probability. - * \param[in] num_attn_heads The number of heads in Q. - * \param[in] num_gqa_groups The number of heads in K, V. - * \param[in] max_seqlen_q The sequence length of Q. - * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim_qk The head dimension of Q, K. - * \param[in] head_dim_v The head dimension of V. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the - * bottom right corner of the softmax matrix. - * \param[in] return_max_logit Whether to produce Max along with Stats. - * \param[in] cuda_graph Whether cuda graph capture is enabled or not. - * \param[in] deterministic Whether determinism is required or not. - * \param[out] message Empty on success, otherwise a diagnostic string describing - * why the configuration was rejected. The string pointer refers to a - * per-thread buffer owned by the library and remains valid - * only until the next call to ``nvte_get_fused_attn_backend`` - * on the same thread; callers that need to retain the - * message across further calls must copy it. Pass NULL to - * skip diagnostics. 
+ * \deprecated This function has been deprecated in favor of nvte_get_fused_attn_backend_v2. + * + * \param[in] is_training Whether the model is in training mode. + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] return_max_logit Whether to produce Max along with Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] deterministic Whether determinism is required or not. 
*/ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, size_t batch_size, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, float attn_scale, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic, - const char **message); + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); /*! \brief Compute dot product attention with separate Q, K and V. * diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index c00feec748..e24a5c4b1b 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -129,6 +129,7 @@ class FusedAttnHelper: head_dim_v: int window_size: Tuple[int, int] bottom_right_diagonal: bool + attn_scale: float = 1.0 def is_fused_attn_kernel_available(self): """Check if there is available fused attention kernel. 
@@ -162,6 +163,7 @@ def get_fused_attn_backend(self): self.attn_bias_type.value, self.attn_mask_type.value, self.softmax_type.value, + self.attn_scale, self.dropout_probability, self.q_num_heads, self.kv_num_heads, @@ -380,6 +382,7 @@ def abstract( v_head_dim, config.window_size, config.bottom_right_diagonal, + attn_scale=float(config.scaling_factor), ).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index cf455357c6..e181f7ed5a 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -152,9 +152,9 @@ std::tuple GetFusedAttnBackend( NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, - size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, - size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + NVTE_Softmax_Type softmax_type, float attn_scale, float dropout_probability, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 9c6f2483cd..b5e40aaf6a 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -14,12 +14,13 @@ namespace jax { std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType 
kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, - size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, - size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic) { + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + float attn_scale, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, + size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic) { if (o_format == NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET) { o_format = nvte_get_q_format(qkv_layout); } @@ -29,15 +30,40 @@ std::tuple GetFusedAttnBackend( if (dqkv_layout == NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET) { dqkv_layout = qkv_layout; } + NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); + + NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + cfg.qkv_layout = qkv_layout; + cfg.o_format = o_format; + cfg.do_format = do_format; + cfg.dqkv_layout = dqkv_layout; + cfg.qkv_scale_inv_format = qkv_scale_inv_format; + cfg.do_scale_inv_format = do_scale_inv_format; + cfg.bias_type = bias_type; + cfg.attn_mask_type = mask_type; + cfg.softmax_type = softmax_type; + cfg.scaling_mode = scaling_mode; + cfg.attn_scale = attn_scale; + cfg.dropout = dropout_probability; + cfg.max_seqlen_q = q_max_seqlen; + cfg.max_seqlen_kv = kv_max_seqlen; + cfg.window_size_left = 
window_size_left;
+  cfg.window_size_right = window_size_right;
+  cfg.bottom_right_diagonal = bottom_right_diagonal;
+  cfg.cuda_graph = false;
+  cfg.qkv_dtype = static_cast<NVTEDType>(q_dtype);
+  cfg.o_dtype = static_cast<NVTEDType>(o_dtype);
+  cfg.batch_size = batch_size;
+  cfg.num_attn_heads = q_attn_heads;
+  cfg.num_gqa_groups = kv_attn_heads;
+  cfg.head_dim_qk = qk_head_dim;
+  cfg.head_dim_v = v_head_dim;
+  cfg.is_training = is_training;
+  cfg.return_max_logit = false;
+  cfg.deterministic = deterministic;
+
   const char *message = nullptr;
-  auto backend = nvte_get_fused_attn_backend(
-      is_training, batch_size, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
-      static_cast<NVTEDType>(o_dtype), scaling_mode, qkv_layout, o_format, do_format, dqkv_layout,
-      qkv_scale_inv_format, do_scale_inv_format, bias_type, mask_type, softmax_type,
-      /*attn_scale=*/1.0f, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
-      kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      bottom_right_diagonal, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic,
-      &message);
+  auto backend = nvte_get_fused_attn_backend_v2(&cfg, &message);
   return {backend, message != nullptr ?
std::string(message) : std::string()}; } @@ -277,17 +303,13 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - const NVTE_QKV_Format probe_o_format = nvte_get_q_format(qkv_layout); - auto backend = nvte_get_fused_attn_backend( - is_training, input_batch, static_cast(dtype), static_cast(dtype), - static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, probe_o_format, - /*do_format=*/probe_o_format, /*dqkv_layout=*/qkv_layout, - /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, - /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, - softmax_type, scaling_factor, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - bottom_right_diagonal, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, - /*message=*/nullptr); + auto [backend, _fwd_msg] = GetFusedAttnBackend( + is_training, input_batch, dtype, dtype, dtype, NVTE_INVALID_SCALING, qkv_layout, + NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, scaling_factor, + dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, + v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, deterministic); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -559,19 +581,13 @@ static void FusedAttnBackwardImpl( /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - // JAX uses the same layout for output / dQKV as for QKV, so derive the formats from 
qkv_layout. - // Scale-inv formats stay NOT_SET because JAX's fused-attn path here is non-FP8. - const NVTE_QKV_Format probe_o_format = nvte_get_q_format(qkv_layout); - auto backend = nvte_get_fused_attn_backend( - is_training, input_batch, static_cast(dtype), static_cast(dtype), - static_cast(dtype), NVTE_INVALID_SCALING, qkv_layout, probe_o_format, - /*do_format=*/probe_o_format, /*dqkv_layout=*/qkv_layout, - /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, - /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, - softmax_type, scaling_factor, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - bottom_right_diagonal, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, - /*message=*/nullptr); + auto [backend, _bwd_msg] = GetFusedAttnBackend( + is_training, input_batch, dtype, dtype, dtype, NVTE_INVALID_SCALING, qkv_layout, + NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, scaling_factor, + dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, + v_head_dim, window_size_left, window_size_right, bottom_right_diagonal, deterministic); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..4112e7c922 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -426,6 +426,7 @@ def __init__( softmax_scale = 1.0 / math.sqrt( kv_channels if isinstance(kv_channels, int) else kv_channels[0] ) + self.softmax_scale = softmax_scale self.deterministic = ( not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) @@ -1441,6 +1442,7 @@ def forward( return_max_logit=self.return_max_logit, cuda_graph=is_graph_capturing(), num_splits=num_splits, + softmax_scale=self.softmax_scale, ) global _attention_backends if is_in_onnx_export_mode(): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index cf2f297083..169add4d6e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -257,6 +257,9 @@ class AttentionParams: Whether support for cuda graph capture is needed or not. num_splits : int, default = 1 The number of kernels to split attention to. + softmax_scale : float, default = 1.0 + Pre-softmax attention scale. Plumbed through to the cuDNN graph cache key so that the + backend probe builds the same execution graph the runtime call later reuses. 
""" qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -290,6 +293,7 @@ class AttentionParams: return_max_logit: bool = False cuda_graph: bool = False num_splits: int = 1 + softmax_scale: float = 1.0 def __eq__(self, other): """ @@ -368,6 +372,7 @@ def get_attention_backend( return_max_logit = attention_params.return_max_logit cuda_graph = attention_params.cuda_graph num_splits = attention_params.num_splits + softmax_scale = attention_params.softmax_scale # Run config logger = logging.getLogger("DotProductAttention") @@ -1262,6 +1267,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], + softmax_scale, attention_dropout, num_heads, num_gqa_groups, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 74021b81b5..1e2b3d356e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -83,10 +83,10 @@ std::tuple get_fused_attn_backend( NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, - bool deterministic); + float attn_scale, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool return_max_logit, bool cuda_graph, bool deterministic); std::vector fused_attn_fwd( size_t 
max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 4cda724a8b..2eff5c21b7 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -46,18 +46,44 @@ std::tuple get_fused_attn_backend( NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, - bool deterministic) { + float attn_scale, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool return_max_logit, bool cuda_graph, bool deterministic) { + NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + cfg.qkv_layout = qkv_layout; + cfg.o_format = o_format; + cfg.do_format = do_format; + cfg.dqkv_layout = dqkv_layout; + cfg.qkv_scale_inv_format = qkv_scale_inv_format; + cfg.do_scale_inv_format = do_scale_inv_format; + cfg.bias_type = bias_type; + cfg.attn_mask_type = attn_mask_type; + cfg.softmax_type = softmax_type; + cfg.scaling_mode = scaling_mode; + cfg.attn_scale = attn_scale; + cfg.dropout = p_dropout; + cfg.max_seqlen_q = max_seqlen_q; + cfg.max_seqlen_kv = max_seqlen_kv; + cfg.window_size_left = window_size_left; + cfg.window_size_right = window_size_right; + cfg.bottom_right_diagonal = bottom_right_diagonal; + cfg.cuda_graph = 
cuda_graph;
+  NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
+  cfg.qkv_dtype = static_cast<NVTEDType>(q_dtype);
+  cfg.o_dtype = static_cast<NVTEDType>(o_dtype);
+  cfg.batch_size = batch_size;
+  cfg.num_attn_heads = num_attn_heads;
+  cfg.num_gqa_groups = num_gqa_groups;
+  cfg.head_dim_qk = head_dim_qk;
+  cfg.head_dim_v = head_dim_v;
+  cfg.is_training = is_training;
+  cfg.return_max_logit = return_max_logit;
+  cfg.deterministic = deterministic;
+
   const char *message = nullptr;
-  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, batch_size, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
-      static_cast<NVTEDType>(o_dtype), scaling_mode, qkv_layout, o_format, do_format, dqkv_layout,
-      qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, softmax_type,
-      /*attn_scale=*/1.0f, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim_qk, head_dim_v, window_size_left, window_size_right, bottom_right_diagonal,
-      return_max_logit, cuda_graph, deterministic, &message);
+  NVTE_Fused_Attn_Backend fused_attention_backend =
+      nvte_get_fused_attn_backend_v2(&cfg, &message);
   return {fused_attention_backend, message != nullptr ?
std::string(message) : std::string()}; } From e86fc6712bd30b71e5da474329d64639ae063d96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 21:44:13 +0000 Subject: [PATCH 20/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 17 +++--- .../common/fused_attn/fused_attn_fp8.cu | 4 +- .../include/transformer_engine/fused_attn.h | 54 ++++++++++--------- .../jax/csrc/extensions/attention.cpp | 13 +++-- .../pytorch/csrc/extensions/attention.cpp | 3 +- 5 files changed, 47 insertions(+), 44 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 6670fd59ed..828f98b50f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -244,7 +244,7 @@ void set_message(const char **message, std::string reason) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig *cfg, - const char **message) { + const char **message) { using namespace transformer_engine; set_message(message, ""); NVTE_CHECK(cfg != nullptr, "NVTEFusedAttnConfig pointer must not be NULL."); @@ -282,8 +282,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig const bool is_fp8 = (cfg->qkv_dtype == NVTEDType::kNVTEFloat8E4M3 || cfg->qkv_dtype == NVTEDType::kNVTEFloat8E5M2); - const bool is_f16_or_bf16 = (cfg->qkv_dtype == NVTEDType::kNVTEFloat16 || - cfg->qkv_dtype == NVTEDType::kNVTEBFloat16); + const bool is_f16_or_bf16 = + (cfg->qkv_dtype == NVTEDType::kNVTEFloat16 || cfg->qkv_dtype == NVTEDType::kNVTEBFloat16); if (is_fp8) { if (cfg->return_max_logit) { @@ -336,8 +336,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig return 
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - set_message(message, - "Unsupported QKV dtype qkv_dtype=" + std::to_string(cfg->qkv_dtype) + " ."); + set_message(message, "Unsupported QKV dtype qkv_dtype=" + std::to_string(cfg->qkv_dtype) + " ."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -475,8 +474,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; cfg.qkv_layout = qkv_layout; cfg.o_format = o_format; - cfg.do_format = o_format; // fwd path: same format used for dO if/when probed for bwd - cfg.dqkv_layout = qkv_layout; // fwd path: same layout used for dQKV if/when probed for bwd + cfg.do_format = o_format; // fwd path: same format used for dO if/when probed for bwd + cfg.dqkv_layout = qkv_layout; // fwd path: same layout used for dQKV if/when probed for bwd cfg.qkv_scale_inv_format = qkv_scale_inv_format; cfg.do_scale_inv_format = qkv_scale_inv_format; // fwd path: mirror QKV cfg.bias_type = bias_type; @@ -493,8 +492,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cfg.cuda_graph = cuda_graph; cfg.qkv_dtype = Q_type; cfg.o_dtype = O_type; - cfg.do_dtype = O_type; // fwd path: dO assumed to share dtype with O - cfg.dqkv_dtype = Q_type; // fwd path: dQKV assumed to share dtype with QKV + cfg.do_dtype = O_type; // fwd path: dO assumed to share dtype with O + cfg.dqkv_dtype = Q_type; // fwd path: dQKV assumed to share dtype with QKV cfg.batch_size = b; cfg.num_attn_heads = h_q; cfg.num_gqa_groups = h_kv; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3fe5b7fb10..be689f2b0c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1325,7 +1325,7 @@ void fused_attn_fp8_bwd( } } -std::string is_supported_fp8_fwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle) { 
+std::string is_supported_fp8_fwd(const NVTEFusedAttnConfig* cfg, cudnnHandle_t handle) { const size_t batch = cfg->batch_size; const size_t num_attn_heads = cfg->num_attn_heads; const size_t num_gqa_groups = cfg->num_gqa_groups; @@ -1376,7 +1376,7 @@ std::string is_supported_fp8_fwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t h } } -std::string is_supported_fp8_bwd(const NVTEFusedAttnConfig *cfg, cudnnHandle_t handle) { +std::string is_supported_fp8_bwd(const NVTEFusedAttnConfig* cfg, cudnnHandle_t handle) { const size_t batch = cfg->batch_size; const size_t num_attn_heads = cfg->num_attn_heads; const size_t num_gqa_groups = cfg->num_gqa_groups; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index df15148350..dba7dd68d6 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -207,9 +207,9 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * in range according to ``struct_size`` and uses safe defaults otherwise. */ typedef struct NVTEFusedAttnConfig { - size_t struct_size; /*!< MUST equal sizeof(NVTEFusedAttnConfig). */ - uint32_t reserved0; /*!< Padding for layout stability; set to 0. */ - uint32_t reserved1; /*!< Padding for layout stability; set to 0. */ + size_t struct_size; /*!< MUST equal sizeof(NVTEFusedAttnConfig). */ + uint32_t reserved0; /*!< Padding for layout stability; set to 0. */ + uint32_t reserved1; /*!< Padding for layout stability; set to 0. */ NVTE_QKV_Layout qkv_layout; /*!< QKV tensors' layout. */ NVTE_QKV_Format o_format; /*!< Output O tensor format. */ @@ -227,18 +227,18 @@ typedef struct NVTEFusedAttnConfig { size_t max_seqlen_kv; /*!< Max sequence length for K, V. */ int64_t window_size_left; /*!< Sliding window size (left half); -1 = unlimited. 
*/ int64_t window_size_right; /*!< Sliding window size (right half); -1 = unlimited. */ - bool bottom_right_diagonal; /*!< Whether causal mask aligns to the bottom-right diagonal. */ - bool cuda_graph; /*!< Whether CUDA graph capture is enabled. */ - - NVTEDType qkv_dtype; /*!< Data type of Tensors Q, K, V. Q and K/V must share a dtype. */ - NVTEDType o_dtype; /*!< Data type of Tensor O. */ - NVTEDType do_dtype; /*!< Data type of Tensor dO (bwd). */ - NVTEDType dqkv_dtype; /*!< Data type of Tensors dQ, dK, dV (bwd). */ - size_t batch_size; /*!< Batch size. */ - size_t num_attn_heads; /*!< Number of heads in Q. */ - size_t num_gqa_groups; /*!< Number of heads in K, V. */ - size_t head_dim_qk; /*!< Head dimension of Q, K. */ - size_t head_dim_v; /*!< Head dimension of V. */ + bool bottom_right_diagonal; /*!< Whether causal mask aligns to the bottom-right diagonal. */ + bool cuda_graph; /*!< Whether CUDA graph capture is enabled. */ + + NVTEDType qkv_dtype; /*!< Data type of Tensors Q, K, V. Q and K/V must share a dtype. */ + NVTEDType o_dtype; /*!< Data type of Tensor O. */ + NVTEDType do_dtype; /*!< Data type of Tensor dO (bwd). */ + NVTEDType dqkv_dtype; /*!< Data type of Tensors dQ, dK, dV (bwd). */ + size_t batch_size; /*!< Batch size. */ + size_t num_attn_heads; /*!< Number of heads in Q. */ + size_t num_gqa_groups; /*!< Number of heads in K, V. */ + size_t head_dim_qk; /*!< Head dimension of Q, K. */ + size_t head_dim_v; /*!< Head dimension of V. */ size_t num_pages_k; /*!< Total number of K cache pages. */ size_t num_pages_v; /*!< Total number of V cache pages. */ @@ -265,15 +265,21 @@ typedef struct NVTEFusedAttnConfig { * flags all default to zero/false; callers must set the fields relevant to * their query. 
*/ -#define NVTE_FUSED_ATTN_CONFIG_INIT \ - { \ - .struct_size = sizeof(NVTEFusedAttnConfig), \ - .qkv_layout = NVTE_QKV_Layout_NOT_SET, .o_format = NVTE_QKV_Format_NOT_SET, \ - .do_format = NVTE_QKV_Format_NOT_SET, .dqkv_layout = NVTE_QKV_Layout_NOT_SET, \ - .qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ - .do_scale_inv_format = NVTE_QKV_Format_NOT_SET, .bias_type = NVTE_NO_BIAS, \ - .attn_mask_type = NVTE_NO_MASK, .softmax_type = NVTE_VANILLA_SOFTMAX, \ - .scaling_mode = NVTE_DELAYED_TENSOR_SCALING, .window_size_left = -1, .window_size_right = -1, \ +#define NVTE_FUSED_ATTN_CONFIG_INIT \ + { \ + .struct_size = sizeof(NVTEFusedAttnConfig), \ + .qkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .o_format = NVTE_QKV_Format_NOT_SET, \ + .do_format = NVTE_QKV_Format_NOT_SET, \ + .dqkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .do_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .bias_type = NVTE_NO_BIAS, \ + .attn_mask_type = NVTE_NO_MASK, \ + .softmax_type = NVTE_VANILLA_SOFTMAX, \ + .scaling_mode = NVTE_DELAYED_TENSOR_SCALING, \ + .window_size_left = -1, \ + .window_size_right = -1, \ } /*! \brief Get fused attention backend based on input parameters. 
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index b5e40aaf6a..cd09628be0 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -14,13 +14,12 @@ namespace jax { std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - float attn_scale, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, - size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic) { + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, float attn_scale, float dropout_probability, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic) { if (o_format == NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET) { o_format = nvte_get_q_format(qkv_layout); } diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2eff5c21b7..84f72ff879 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -82,8 +82,7 @@ std::tuple get_fused_attn_backend( cfg.deterministic = deterministic; 
const char *message = nullptr; - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend_v2(&cfg, &message); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend_v2(&cfg, &message); return {fused_attention_backend, message != nullptr ? std::string(message) : std::string()}; } From e2561d0d4c406b7bef4dacf8b9c7abf9db1e6c87 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 12 May 2026 16:20:35 -0700 Subject: [PATCH 21/23] fix FP8 tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4c8435f246..681dbea2c8 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1775,12 +1775,23 @@ def test_dpa_fp8_extra_state(model, dtype): config = model_configs_fp8_extra_state[model] # Test backend availability is_training = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout="sb3hd", is_training=is_training, deterministic=_deterministic, + fp8=True, + fp8_meta=fp8_meta, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported and not flash_attn_supported: @@ -2567,6 +2578,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): Both paths take F16 input and output. 
QKV layout is bs3hd""" config = model_configs_fp8[model] + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability is_training = True From 724a12f009d96a0dc63775e5cff8cae6023d0c33 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 12 May 2026 16:20:52 -0700 Subject: [PATCH 22/23] add do_dtype and dqkv_dtype to API Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 15 +++++-------- .../common/fused_attn/fused_attn_fp8.cu | 6 ++++-- .../jax/cpp_extensions/attention.py | 2 ++ transformer_engine/jax/csrc/extensions.h | 15 ++++++------- .../jax/csrc/extensions/attention.cpp | 21 +++++++++++-------- .../attention/dot_product_attention/utils.py | 10 +++++++++ transformer_engine/pytorch/csrc/extensions.h | 5 +++-- .../pytorch/csrc/extensions/attention.cpp | 7 +++++-- 8 files changed, 49 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 828f98b50f..0fbd9a21ae 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -342,17 +342,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig // Deprecated: thin wrapper preserving the historical narrow signature. New callers should // construct an NVTEFusedAttnConfig and call nvte_get_fused_attn_backend_v2 directly to access -// the additional fields (attn_scale, format/layout fields, scaling_mode, paged-KV/bias shape, etc.) -// that this wrapper cannot express. +// the additional fields (attn_scale, format/layout fields, scaling_mode, paged-KV/bias shape, +// dO/dQKV dtypes, etc.) that this wrapper cannot express. 
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + (void)is_training; NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; cfg.qkv_layout = qkv_layout; - cfg.dqkv_layout = qkv_layout; // legacy: gradient layout matches input layout cfg.bias_type = bias_type; cfg.attn_mask_type = attn_mask_type; cfg.softmax_type = softmax_type; @@ -371,7 +371,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( cfg.num_gqa_groups = num_gqa_groups; cfg.head_dim_qk = head_dim_qk; cfg.head_dim_v = head_dim_v; - cfg.is_training = is_training; + cfg.is_training = false; // legacy wrapper cannot express dO/dQKV dtypes; skip bwd probe cfg.return_max_logit = return_max_logit; cfg.deterministic = deterministic; return nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr); @@ -474,10 +474,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; cfg.qkv_layout = qkv_layout; cfg.o_format = o_format; - cfg.do_format = o_format; // fwd path: same format used for dO if/when probed for bwd - cfg.dqkv_layout = qkv_layout; // fwd path: same layout used for dQKV if/when probed for bwd cfg.qkv_scale_inv_format = qkv_scale_inv_format; - cfg.do_scale_inv_format = qkv_scale_inv_format; // fwd path: mirror QKV cfg.bias_type = bias_type; cfg.attn_mask_type = attn_mask_type; cfg.softmax_type = softmax_type; @@ -492,8 +489,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cfg.cuda_graph = cuda_graph; cfg.qkv_dtype = Q_type; cfg.o_dtype = O_type; - 
cfg.do_dtype = O_type; // fwd path: dO assumed to share dtype with O - cfg.dqkv_dtype = Q_type; // fwd path: dQKV assumed to share dtype with QKV cfg.batch_size = b; cfg.num_attn_heads = h_q; cfg.num_gqa_groups = h_kv; @@ -509,7 +504,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cfg.bias_num_heads = bias_h; cfg.bias_seqlen_q = bias_sq; cfg.bias_seqlen_kv = bias_skv; - cfg.is_training = is_training; + cfg.is_training = false; cfg.return_max_logit = return_max_logit; cfg.deterministic = false; NVTE_Fused_Attn_Backend fused_attention_backend = diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index be689f2b0c..180bee2ab0 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1401,12 +1401,14 @@ std::string is_supported_fp8_bwd(const NVTEFusedAttnConfig* cfg, cudnnHandle_t h const bool deterministic = cfg->deterministic; const DType qkv_dtype = static_cast(cfg->qkv_dtype); const DType o_dtype = static_cast(cfg->o_dtype); + const DType do_dtype = static_cast(cfg->do_dtype); + const DType dqkv_dtype = static_cast(cfg->dqkv_dtype); const NVTEScalingMode scaling_mode = cfg->scaling_mode; const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); - const cudnn_frontend::DataType_t do_t = o_t; - const cudnn_frontend::DataType_t dqkv_t = qkv_t; + const cudnn_frontend::DataType_t do_t = get_cudnn_fe_dtype(do_dtype); + const cudnn_frontend::DataType_t dqkv_t = get_cudnn_fe_dtype(dqkv_dtype); size_t workspace_size = 0; try { fused_attn::fused_attn_fp8_bwd_impl( diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index e24a5c4b1b..a895d8eac3 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ 
b/transformer_engine/jax/cpp_extensions/attention.py @@ -153,6 +153,8 @@ def get_fused_attn_backend(self): q_type, jax_dtype_to_te_dtype(self.kv_dtype), q_type, + q_type, + q_type, NVTEScalingMode.NVTE_INVALID_SCALING, self.qkv_layout.value, NVTE_QKV_Format.NVTE_QKV_Format_NOT_SET, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index e181f7ed5a..b2adb3b042 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -149,13 +149,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, float attn_scale, float dropout_probability, - size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic); + DType do_dtype, DType dqkv_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + float attn_scale, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, + size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t 
bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index cd09628be0..573186b78d 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -13,13 +13,14 @@ namespace jax { std::tuple GetFusedAttnBackend( bool is_training, size_t batch_size, DType q_dtype, DType kv_dtype, DType o_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, float attn_scale, float dropout_probability, - size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, bool deterministic) { + DType do_dtype, DType dqkv_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + float attn_scale, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, + size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic) { if (o_format == NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET) { o_format = nvte_get_q_format(qkv_layout); } @@ -52,6 +53,8 @@ std::tuple GetFusedAttnBackend( cfg.cuda_graph = false; cfg.qkv_dtype = static_cast(q_dtype); cfg.o_dtype = static_cast(o_dtype); + cfg.do_dtype = static_cast(do_dtype); + cfg.dqkv_dtype = 
static_cast(dqkv_dtype); cfg.batch_size = batch_size; cfg.num_attn_heads = q_attn_heads; cfg.num_gqa_groups = kv_attn_heads; @@ -303,7 +306,7 @@ static void FusedAttnForwardImpl( auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); auto [backend, _fwd_msg] = GetFusedAttnBackend( - is_training, input_batch, dtype, dtype, dtype, NVTE_INVALID_SCALING, qkv_layout, + is_training, input_batch, dtype, dtype, dtype, dtype, dtype, NVTE_INVALID_SCALING, qkv_layout, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, scaling_factor, @@ -581,7 +584,7 @@ static void FusedAttnBackwardImpl( NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto [backend, _bwd_msg] = GetFusedAttnBackend( - is_training, input_batch, dtype, dtype, dtype, NVTE_INVALID_SCALING, qkv_layout, + is_training, input_batch, dtype, dtype, dtype, dtype, dtype, NVTE_INVALID_SCALING, qkv_layout, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Layout::NVTE_QKV_Layout_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, bias_type, mask_type, softmax_type, scaling_factor, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 169add4d6e..a345e2c352 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1229,6 +1229,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt q_type = TE_DType[qkv_dtype] kv_type = q_type o_type = q_type + do_type = q_type + dqkv_type = q_type scaling_mode = tex.NVTEScalingMode.NVTE_INVALID_SCALING qkv_scale_inv_format = None 
do_scale_inv_format = None @@ -1240,14 +1242,20 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if recipe.mxfp8(): scaling_mode = tex.NVTEScalingMode.NVTE_MXFP8_1D_SCALING o_type = TE_DType[torch.bfloat16] + do_type = TE_DType[torch.bfloat16] + dqkv_type = TE_DType[torch.bfloat16] qkv_scale_inv_format = "bhsd" do_scale_inv_format = "bhsd" elif recipe.float8_current_scaling() and cs_o_in_f16: scaling_mode = tex.NVTEScalingMode.NVTE_DELAYED_TENSOR_SCALING o_type = TE_DType[torch.bfloat16] + do_type = TE_DType[torch.bfloat16] + dqkv_type = TE_DType[torch.bfloat16] else: scaling_mode = tex.NVTEScalingMode.NVTE_DELAYED_TENSOR_SCALING o_type = q_type + do_type = o_type + dqkv_type = q_type o_format = q_format do_format = o_format dqkv_layout = qkv_layout @@ -1257,6 +1265,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt q_type, kv_type, o_type, + do_type, + dqkv_type, scaling_mode, QKVLayout[qkv_layout], QKVFormat[o_format], diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1e2b3d356e..019bf5afea 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -79,8 +79,9 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T // describing why the configuration was rejected when backend = NVTE_No_Backend. 
std::tuple get_fused_attn_backend( bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, - const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + const DType o_dtype, const DType do_dtype, const DType dqkv_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float attn_scale, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 84f72ff879..3f2a1d4399 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -42,8 +42,9 @@ namespace transformer_engine::pytorch { // get the fused attention backend std::tuple get_fused_attn_backend( bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, - const DType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + const DType o_dtype, const DType do_dtype, const DType dqkv_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float attn_scale, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, @@ -72,6 +73,8 @@ std::tuple get_fused_attn_backend( NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); 
cfg.qkv_dtype = static_cast(q_dtype); cfg.o_dtype = static_cast(o_dtype); + cfg.do_dtype = static_cast(do_dtype); + cfg.dqkv_dtype = static_cast(dqkv_dtype); cfg.batch_size = batch_size; cfg.num_attn_heads = num_attn_heads; cfg.num_gqa_groups = num_gqa_groups; From 3ae36df37a92a370a57943da30a781ff2696e6c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 May 2026 23:24:23 +0000 Subject: [PATCH 23/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions.h | 17 ++++++++--------- .../pytorch/csrc/extensions/attention.cpp | 17 ++++++++--------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 019bf5afea..9931d4b3a8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -79,15 +79,14 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T // describing why the configuration was rejected when backend = NVTE_No_Backend. 
std::tuple get_fused_attn_backend( bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, - const DType o_dtype, const DType do_dtype, const DType dqkv_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float attn_scale, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool return_max_logit, bool cuda_graph, bool deterministic); + const DType o_dtype, const DType do_dtype, const DType dqkv_dtype, NVTEScalingMode scaling_mode, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float attn_scale, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 3f2a1d4399..afcdcae015 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -42,15 +42,14 @@ namespace transformer_engine::pytorch { // get the fused 
attention backend std::tuple get_fused_attn_backend( bool is_training, size_t batch_size, const DType q_dtype, const DType kv_dtype, - const DType o_dtype, const DType do_dtype, const DType dqkv_dtype, - NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float attn_scale, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool return_max_logit, bool cuda_graph, bool deterministic) { + const DType o_dtype, const DType do_dtype, const DType dqkv_dtype, NVTEScalingMode scaling_mode, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, + NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float attn_scale, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool return_max_logit, bool cuda_graph, bool deterministic) { NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; cfg.qkv_layout = qkv_layout; cfg.o_format = o_format;