From 10f7ee660ad651dc67908d875e5528a06fe94ac2 Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Tue, 24 Feb 2026 19:10:35 +0000
Subject: [PATCH 01/21] Integrate CK varlen cross attention for small-seq
 (s_q=1, s_kv<=16)

Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for cross-attention (Q length 1, KV length 2-16, large batch).

- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
  fused_attn_rocm/: declarations and implementation adapted from
  varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
  grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
  max_seqlen_kv in 2..16 (tested at {2,4,6,8,12,16}), head_dim 128, BF16.
- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
  transformer_engine/common/CMakeLists.txt.
- In fused_attn_ck_fwd: when THD and bias is NO_BIAS or ALIBI, branch to the
  small-seq path when max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape
  query (Aux_CTX_Tensors->size == 0) skip get_runtime_max_seqlen (cu_seqlens
  pointers are null); use host max_seqlen_kv and set output_S to the
  attention-weights shape {max_tokens_q, h_q, 1, runtime_max_seqlen_kv} and
  dtype QKV_type. On a real run (size >= 2) call get_runtime_max_seqlen then
  fused_attn_smallseq_fwd. Use the sequence count b_varlen = max_tokens_q
  (not the segment count b) for get_runtime_max_seqlen, the output_S shape,
  the workspace size, and the small-seq fwd, so varlen kernel indexing
  matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; the
  varlen kernel expects a sequence-level batch).
- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
  (workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
  max_seqlen_kv; on a real run call get_runtime_max_seqlen then
  fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
  get_runtime_max_seqlen, the workspace size, and the small-seq bwd.
- Reuse the softmax LSE auxiliary buffer for attention weights in the
  small-seq path (forward writes it, backward reads it).
- JAX attention.py: in the NVTE_CK block, when THD and q_max_seqlen==1 and
  kv_max_seqlen<=16, set softmax_shape = (*batch_shape, attn_heads,
  q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so the Python aux
  buffer matches the C++ attention-weights convention.
- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
  (parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
  SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
  C++.
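For reviewers, a minimal sketch of the dispatch predicate this patch adds
(the helper name is illustrative only and not part of the diff below; the
enums are the existing NVTE ones from transformer_engine/fused_attn.h):

    // Illustrative condensation of the fwd/bwd branch condition; names
    // mirror fused_attn_ck_fwd in this patch, helper itself is hypothetical.
    static bool use_small_seq_path(bool is_thd, NVTE_Bias_Type bias_type,
                                   size_t max_seqlen_q,
                                   size_t runtime_max_seqlen_kv) {
      const bool bias_ok =
          (bias_type == NVTE_NO_BIAS) || (bias_type == NVTE_ALIBI);
      return is_thd && bias_ok && max_seqlen_q == 1 &&
             runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16;
    }
    // When this returns true, aux tensor 0 carries attention weights rather
    // than softmax LSE: shape {max_tokens_q, h_q, 1, runtime_max_seqlen_kv},
    // dtype = QKV dtype (mirrored on the JAX side by softmax_shape and
    // softmax_dtype in attention.py).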
--- tests/jax/test_fused_attn.py | 64 +- transformer_engine/common/CMakeLists.txt | 1 + .../common/fused_attn_rocm/fused_attn_ck.cpp | 188 ++- .../fused_attn_rocm/fused_attn_smallseq.cpp | 1049 +++++++++++++++++ .../fused_attn_rocm/fused_attn_smallseq.hpp | 89 ++ .../jax/cpp_extensions/attention.py | 27 +- 6 files changed, 1409 insertions(+), 9 deletions(-) create mode 100644 transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp create mode 100644 transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 4d7718cd0..114099b16 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -539,7 +539,11 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - self.num_segments_per_seq = 2 + # For very small sequence lengths, use 1 segment instead of 2 + # to avoid division by zero in segment size calculation + # Use the minimum of Q and KV sequence lengths to ensure both work + min_seqlen = min(self.max_seqlen_q, self.max_seqlen_kv) + self.num_segments_per_seq = 2 if min_seqlen > 1 else 1 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) @@ -1214,3 +1218,61 @@ def test_jax_new_rng(): ) runner = FusedAttnRunner(**kwargs) runner.test_forward() + + +# ROCm CK internal small-seq (varlen unfused) branch tests. +# Uses THD_THD_THD with s_q=1, s_kv<=16 so the small-seq path is taken. +@pytest.mark.skipif( + not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" +) +@pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype", + [ + pytest.param(30720, 1, 2, 16, 16, 128, 128, jnp.bfloat16, + id="30720-1-2-16-16-128-128-BF16"), + pytest.param(30720, 1, 4, 16, 16, 128, 128, jnp.bfloat16, + id="30720-1-4-16-16-128-128-BF16"), + pytest.param(30720, 1, 6, 16, 16, 128, 128, jnp.bfloat16, + id="30720-1-6-16-16-128-128-BF16"), + pytest.param(30720, 1, 8, 16, 16, 128, 128, jnp.bfloat16, + id="30720-1-8-16-16-128-128-BF16"), + pytest.param(30720, 1, 12, 16, 16, 128, 128, jnp.bfloat16, + id="30720-1-12-16-16-128-128-BF16"), + pytest.param(30720, 1, 16, 16, 16, 128, 128, jnp.bfloat16, + id="30720-1-16-16-16-128-128-BF16"), + ], +) +def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype): + """ + Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout. + Uses THD_THD_THD (Q,K,V all THD). + """ + runner = FusedAttnRunner( + batch_size=b, + max_seqlen_q=s_q, + max_seqlen_kv=s_kv, + num_heads_q=h_q, + num_heads_kv=h_kv, + head_dim_qk=d_qk, + head_dim_v=d_v, + attn_bias_type=AttnBiasType.NO_BIAS, + attn_mask_type=AttnMaskType.PADDING_MASK, + dropout_prob=0.0, + use_old_rng=True, + dtype=dtype, + is_training=True, + qkv_layout=QKVLayout.THD_THD_THD, + bias_shape=None, + window_size=None, + seq_desc_format=SeqDescFormat.Seqlens, + ) + runner._setup_inputs() + expected_backend = NVTE_Fused_Attn_Backend.NVTE_CK + if runner.backend != expected_backend: + pytest.skip( + f"Backend selection failed: expected {expected_backend}, got {runner.backend}. 
" + f"Config: b={b}, s_q={s_q}, s_kv={s_kv}, h_q={h_q}, h_kv={h_kv}, " + f"d_qk={d_qk}, d_v={d_v}, dtype={dtype}" + ) + runner.test_forward() + runner.test_backward() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 50dcf90a0..6774acfd2 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -200,6 +200,7 @@ else() fused_attn_rocm/fused_attn.cpp fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/fused_attn_smallseq.cpp fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu amd_detail/system.cpp) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7ca6fc95f..7beead7b3 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -9,6 +9,8 @@ #include // Required for std::accumulate #ifdef USE_FUSED_ATTN_CK #include +#include "../../ck_fused_attn/src/ck_fused_attn_utils.hpp" +#include "fused_attn_smallseq.hpp" #endif // USE_FUSED_ATTN_CK #include "../util/cuda_runtime.h" #include "../util/system.h" @@ -1828,18 +1830,76 @@ void fused_attn_ck_fwd( size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast(1), std::multiplies())/h_q/d_qk; size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast(1), std::multiplies())/h_kv/d_qk; - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; + bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; + size_t runtime_max_seqlen_kv = max_seqlen_kv; + bool use_small_seq = false; + const bool log_smallseq = (std::getenv("NVTE_LOG_CK_SMALLSEQ") != nullptr); + if (log_smallseq) { + std::cerr << "[CK small-seq] fused_attn_ck_fwd ENTRY: b=" << b << " h_q=" << h_q + << " max_seqlen_q=" << max_seqlen_q << " max_seqlen_kv=" << max_seqlen_kv + << " is_ragged=" << is_ragged << " Aux_CTX_size=" << Aux_CTX_Tensors->size << std::endl; + } +#ifdef USE_FUSED_ATTN_CK + // THD can pass segment-level cu_seqlens (length b). Varlen kernel expects sequence-level batch; + // when max_seqlen_q==1, max_tokens_q == number of sequences → use as batch in varlen path. 
+ if (is_ragged && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || bias_type == NVTE_Bias_Type::NVTE_ALIBI)) { + const size_t b_varlen = max_tokens_q; + if (Aux_CTX_Tensors->size == 0) { + runtime_max_seqlen_kv = max_seqlen_kv; + use_small_seq = (max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16); + if (log_smallseq) { + std::cerr << "[CK small-seq] FWD shape query (size==0): skip get_runtime_max_seqlen, " + << "use host max_seqlen_kv=" << max_seqlen_kv << " use_small_seq=" << use_small_seq + << std::endl; + } + } else { + if (log_smallseq) { + std::cerr << "[CK small-seq] FWD THD branch: calling get_runtime_max_seqlen (b_varlen=" << b_varlen + << " devPtrCuSeqlensKV=" << devPtrCuSeqlensKV + << " devPtrSeqOffsetsKV=" << devPtrSeqOffsetsKV << ")" << std::endl; + } + void* max_seqlen_workspace = workspace->data.dptr; + bool need_free = false; + if (max_seqlen_workspace == nullptr) { + NVTE_CHECK_CUDA(hipMalloc(&max_seqlen_workspace, sizeof(uint64_t))); + need_free = true; + } + runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( + static_cast(b_varlen), devPtrCuSeqlensKV, devPtrSeqOffsetsKV, + max_seqlen_workspace, reinterpret_cast(stream))); + if (need_free) { + NVTE_CHECK_CUDA(hipFree(max_seqlen_workspace)); + } + use_small_seq = (max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16); + if (log_smallseq) { + std::cerr << "[CK small-seq FWD] get_runtime_max_seqlen returned " << runtime_max_seqlen_kv + << " use_small_seq=" << use_small_seq << std::endl; + } + if (use_small_seq && log_smallseq) { + std::cerr << "[CK small-seq FWD] Dispatch: using specialized varlen kernel. " + << "b_varlen=" << b_varlen << " h_q=" << h_q << " h_kv=" << h_kv + << " max_seqlen_q=" << max_seqlen_q << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv + << " d_qk=" << d_qk << " d_v=" << d_v << " is_training=" << is_training + << " attn_scale=" << attn_scale << " dropout=" << dropout << std::endl; + } + } + } +#endif if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if(is_ragged){ + if (use_small_seq) { + output_S->data.shape = {max_tokens_q, h_q, 1, runtime_max_seqlen_kv}; + output_S->data.dtype = QKV_type; + } else if(is_ragged){ output_S->data.shape = {max_tokens_q, h_q, 1}; + output_S->data.dtype = DType::kFloat32; }else{ output_S->data.shape = {b, h_q, max_seqlen_q, 1}; + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -1852,17 +1912,33 @@ void fused_attn_ck_fwd( Aux_CTX_Tensors->size = 2; Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if(is_ragged){ + if (use_small_seq) { + output_S->data.shape = {max_tokens_q, h_q, 1, runtime_max_seqlen_kv}; + output_S->data.dtype = QKV_type; + } else if(is_ragged){ output_S->data.shape = {max_tokens_q, h_q, 1}; + output_S->data.dtype = DType::kFloat32; }else{ output_S->data.shape = {b, h_q, max_seqlen_q, 1}; + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; output_rng_state->data.dtype = 
DType::kInt64; } + if (use_small_seq) { + if (log_smallseq) { + std::cerr << "[CK small-seq FWD] Shape query: output_S shape={max_tokens_q,h_q,1,runtime_max_seqlen_kv}=" + << "{" << max_tokens_q << "," << h_q << ",1," << runtime_max_seqlen_kv << "}, dtype=QKV_type" + << std::endl; + } + size_t small_seq_ws = fused_attn_rocm::fused_attn_smallseq_bwd_workspace_size( + max_tokens_q, h_q, runtime_max_seqlen_kv, QKV_type); + workspace->data.shape = {small_seq_ws > 8u ? small_seq_ws : 8u}; + workspace->data.dtype = DType::kByte; + return; + } } else if (Aux_CTX_Tensors->size == 2) { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); devPtrS = output_S->data.dptr; @@ -1883,6 +1959,35 @@ void fused_attn_ck_fwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + + if (use_small_seq && (Aux_CTX_Tensors->size == 2 || Aux_CTX_Tensors->size == 3)) { + if (log_smallseq) { + std::cerr << "[CK small-seq FWD] Running specialized kernel: b_varlen=" << max_tokens_q << " h_q=" << h_q + << " h_kv=" << h_kv << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv + << " d_qk=" << d_qk << " d_v=" << d_v << " is_training=" << is_training + << " attn_scale=" << attn_scale << " dropout=" << dropout + << " Aux_CTX_Tensors->size=" << Aux_CTX_Tensors->size << std::endl; + } + fused_attn_rocm::fused_attn_smallseq_fwd( + max_tokens_q, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, + is_training, attn_scale, dropout, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrS, + devPtrCuSeqlensKV, devPtrSeqOffsetsKV, + rng_state->data.dptr, + reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1), + QKV_type, workspace->data.dptr, &workspace_size, stream); + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + return; + } + } else { + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; + } + return; + } fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, @@ -1967,8 +2072,79 @@ void fused_attn_ck_bwd( void *devPtrSeqOffsetsKV = input_cu_seqlens_kv_padded->data.dptr; size_t workspace_size = 0; + size_t max_tokens_q_bwd = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast(1), std::multiplies()) / h_q / d_qk; + + bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; + size_t runtime_max_seqlen_kv_bwd = max_seqlen_kv; + bool use_small_seq_bwd = false; + const bool log_smallseq_bwd = (std::getenv("NVTE_LOG_CK_SMALLSEQ") != nullptr); + if (log_smallseq_bwd) { + std::cerr << "[CK small-seq] fused_attn_ck_bwd ENTRY: b=" << b << " h_q=" << h_q + << " max_seqlen_q=" << max_seqlen_q << " max_seqlen_kv=" << max_seqlen_kv + << " is_ragged=" << is_ragged << std::endl; + } + // Varlen path uses sequence count (max_tokens_q) as batch; see comment in fused_attn_ck_fwd. 
+ if (is_ragged && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || bias_type == NVTE_Bias_Type::NVTE_ALIBI)) { + const size_t b_varlen = max_tokens_q_bwd; + if (workspace->data.dptr == nullptr) { + runtime_max_seqlen_kv_bwd = max_seqlen_kv; + use_small_seq_bwd = (max_seqlen_q == 1 && runtime_max_seqlen_kv_bwd >= 2 && runtime_max_seqlen_kv_bwd <= 16); + if (log_smallseq_bwd) { + std::cerr << "[CK small-seq] BWD workspace query (workspace==null): skip get_runtime_max_seqlen, " + << "use host max_seqlen_kv=" << max_seqlen_kv << " use_small_seq_bwd=" << use_small_seq_bwd + << std::endl; + } + } else { + if (log_smallseq_bwd) { + std::cerr << "[CK small-seq] BWD THD branch: calling get_runtime_max_seqlen (b_varlen=" << b_varlen << ")" << std::endl; + } + void* max_seqlen_workspace_bwd = workspace->data.dptr; + runtime_max_seqlen_kv_bwd = static_cast(ck_fused_attn::get_runtime_max_seqlen( + static_cast(b_varlen), devPtrCuSeqlensKV, devPtrSeqOffsetsKV, + max_seqlen_workspace_bwd, reinterpret_cast(stream))); + use_small_seq_bwd = (max_seqlen_q == 1 && runtime_max_seqlen_kv_bwd >= 2 && runtime_max_seqlen_kv_bwd <= 16); + if (log_smallseq_bwd) { + std::cerr << "[CK small-seq BWD] get_runtime_max_seqlen returned " << runtime_max_seqlen_kv_bwd + << " use_small_seq_bwd=" << use_small_seq_bwd << std::endl; + } + } + if (use_small_seq_bwd && log_smallseq_bwd) { + std::cerr << "[CK small-seq BWD] Dispatch: using specialized varlen kernel. " + << "b_varlen=" << max_tokens_q_bwd << " h_q=" << h_q << " h_kv=" << h_kv + << " max_seqlen_q=" << max_seqlen_q << " runtime_max_seqlen_kv_bwd=" << runtime_max_seqlen_kv_bwd + << " d_qk=" << d_qk << " d_v=" << d_v + << " attn_scale=" << attn_scale << " dropout=" << dropout << std::endl; + } + } + if (use_small_seq_bwd) { + size_t small_seq_bwd_workspace = fused_attn_rocm::fused_attn_smallseq_bwd_workspace_size( + max_tokens_q_bwd, h_q, runtime_max_seqlen_kv_bwd, QKV_type); + if (workspace->data.dptr == nullptr) { + if (log_smallseq_bwd) { + std::cerr << "[CK small-seq BWD] Workspace query: workspace_size=" << small_seq_bwd_workspace << std::endl; + } + workspace->data.shape = {small_seq_bwd_workspace}; + workspace->data.dtype = DType::kByte; + return; + } + if (log_smallseq_bwd) { + std::cerr << "[CK small-seq BWD] Running specialized kernel: b_varlen=" << max_tokens_q_bwd << " h_q=" << h_q + << " h_kv=" << h_kv << " runtime_max_seqlen_kv_bwd=" << runtime_max_seqlen_kv_bwd + << " d_qk=" << d_qk << " d_v=" << d_v + << " attn_scale=" << attn_scale << " dropout=" << dropout << std::endl; + } + fused_attn_rocm::fused_attn_smallseq_bwd( + max_tokens_q_bwd, h_q, h_kv, runtime_max_seqlen_kv_bwd, d_qk, d_v, + attn_scale, dropout, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxStats, + devPtrdQ, devPtrdK, devPtrdV, + devPtrCuSeqlensKV, devPtrSeqOffsetsKV, + QKV_type, workspace->data.dptr, &workspace_size, stream); + workspace->data.shape = {workspace_size > 0 ? 
workspace_size : 1}; + workspace->data.dtype = DType::kByte; + return; + } - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp new file mode 100644 index 000000000..b36365fb0 --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp @@ -0,0 +1,1049 @@ +/************************************************************************* + * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +/*! \file fused_attn_smallseq.cpp + * \brief Unfused small-seq (varlen) attention: seq_q=1, max_seqlen_kv<=16, THD only. + * Ported from varlen_attn/attn_fwd.cpp and attn_bwd.cpp with runtime b, head_num. + */ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "fused_attn_smallseq.hpp" +#include "utils.h" + +namespace transformer_engine { +namespace fused_attn_rocm { + +enum class CausalMaskType { DISABLE = 0, TOP_LEFT = 1, BOTTOM_RIGHT = 2 }; + +template +struct SmallSeqConfig { + static constexpr int seq_q = 1; + static constexpr int max_seq_kv = MAX_SEQ_KV; + static constexpr int head_dim = HEAD_DIM; + static constexpr int step2_block_size = STEP2_BLOCK_SIZE; + static constexpr bool enable_dropout_mask = ENABLE_DROPOUT_MASK; + static constexpr CausalMaskType mask_type = MASK_TYPE; +}; + +// ----- Forward kernels (with runtime batch_size, head_num) ----- + +template +__global__ void compute_scores_kernel(const T* Q, + const T* K, + T* scores, + float scale, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = 64; + constexpr int thread_block_size = 64; + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block; + int thread_id = threadIdx.x; + + for (int task = 0; task < tasks_per_block; task++) { + int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id; + int batch_idx = cur_batch_idx / (seq_q * head_num); + int seq_head_idx = cur_batch_idx % (seq_q * head_num); + int seq_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int kv_offset = cu_seqlens_kv_padded[batch_idx]; + + float results[max_seq_kv]; + T fetch_Q[block_k]; + T fetch_K[block_k]; + T* Q_ptr = (T*)&Q[(batch_idx * seq_q * head_num + seq_idx * head_num + head_idx) * head_dim]; + T* K_ptr = (T*)&K[(kv_offset * head_num + head_idx) * head_dim]; + T* score_ptr = (T*)&scores[cur_batch_idx * max_seq_kv]; + uint4 ls_dwordx4_tmp_var; + for (int i = 0; i < seq_kv; i++) + results[i] = 0.0f; + for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { + if constexpr (std::is_same::value) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = 
*((uint4*)&Q_ptr[dim_offset + k * 8]); + fetch_Q[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; + fetch_Q[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; + fetch_Q[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; + fetch_Q[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; + fetch_Q[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; + fetch_Q[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; + fetch_Q[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; + fetch_Q[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&K_ptr[kv_idx * head_num * head_dim + dim_offset + k * 8]); + fetch_K[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; + fetch_K[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; + fetch_K[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; + fetch_K[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; + fetch_K[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; + fetch_K[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; + fetch_K[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; + fetch_K[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += static_cast(fetch_Q[k]) * static_cast(fetch_K[k]); + } + } else { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset + k * 4]); + fetch_Q[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_Q[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_Q[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_Q[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&K_ptr[kv_idx * head_num * head_dim + dim_offset + k * 4]); + fetch_K[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_K[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_K[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_K[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += fetch_Q[k] * fetch_K[k]; + } + } + } + for (int i = 0; i < seq_kv; i++) + score_ptr[i] = T(results[i] * scale); + for (int i = seq_kv; i < max_seq_kv; i++) + score_ptr[i] = T(-1e9f); + } +} + +template +__global__ void apply_mask_and_softmax_kernel(T* scores, + const T* dropout_mask, + float dropout_scale, + const int* cu_seqlens_kv, + int batch_size, + int head_num) +{ + const uint32_t block_id = blockIdx.x; + const uint32_t thread_id = threadIdx.x; + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int block_size = Config::step2_block_size; + constexpr int per_score_size = seq_q * max_seq_kv; + constexpr int valid_thread_range = block_size / per_score_size * per_score_size; + const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id; + const uint32_t total_elt = static_cast(batch_size) * head_num * seq_q * max_seq_kv; + bool is_tail = block_id * valid_thread_range + block_size >= total_elt; + int real_row_num = + is_tail ? 
(total_elt - block_id * valid_thread_range) / max_seq_kv + : valid_thread_range / max_seq_kv; + + if (cur_block_offset < total_elt && thread_id < valid_thread_range) { + __shared__ T tmp_scores[valid_thread_range]; + constexpr int row_num = valid_thread_range / max_seq_kv; + __shared__ T row_max[row_num]; + __shared__ T row_sum[row_num]; + + int global_row_idx = cur_block_offset / max_seq_kv; + int batch_idx = global_row_idx / (seq_q * head_num); + int k_idx = cur_block_offset % max_seq_kv; + + int seq_kv = (batch_idx < batch_size) + ? (cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]) + : max_seq_kv; + + T score_value = scores[cur_block_offset]; + tmp_scores[thread_id] = score_value; + + if constexpr (Config::mask_type == CausalMaskType::TOP_LEFT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx > q_idx || k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } else if constexpr (Config::mask_type == CausalMaskType::BOTTOM_RIGHT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx < q_idx || k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } else { + if (k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } + __syncthreads(); + + if (thread_id < real_row_num) { + T max_val = T(-1e9f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + max_val = fmaxf(static_cast(max_val), + static_cast(tmp_scores[thread_id * max_seq_kv + i])); + row_max[thread_id] = max_val; + } + __syncthreads(); + + T exp_val = T(expf(static_cast(tmp_scores[thread_id] - + row_max[thread_id / max_seq_kv]))); + tmp_scores[thread_id] = exp_val; + __syncthreads(); + + if (thread_id < real_row_num) { + T sum = T(0.0f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + sum += tmp_scores[thread_id * max_seq_kv + i]; + row_sum[thread_id] = sum; + } + __syncthreads(); + + T attn_weight = tmp_scores[thread_id] / row_sum[thread_id / max_seq_kv]; + if constexpr (Config::enable_dropout_mask) { + attn_weight = attn_weight * dropout_mask[cur_block_offset] * dropout_scale; + } + scores[cur_block_offset] = attn_weight; + } +} + +template +__global__ void compute_output_kernel(const T* attn_weights, + const T* V, + T* O, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt], + store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T attn[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - 
cu_seqlens_kv[batch_idx]; + int kv_offset = cu_seqlens_kv_padded[batch_idx]; + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + attn[i] = attn_weights[cur_idx * max_seq_kv + i]; + for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&V[((kv_offset + j) * head_num + head_idx) * head_dim + thread_head_offset + + i * dwordx4_load_elt]); + } +#pragma unroll + for (int b = 0; b < block_k; b++) + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + attn[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) + *((uint4*)&O[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]) = store_dwordx4_tmp_var[i]; + } +} + +// ----- Forward launcher ----- + +template +void run_attn_fwd_impl(int b, + int head_num, + const T* Q, + const T* K, + const T* V, + const T* dropout_mask, + float dropout_p, + float sqr_dk_scale, + T* O, + T* workspace, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + hipStream_t stream) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int warp_size = 64; + + int merge_bs = b * head_num; + float scale = sqr_dk_scale; + float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f; + + constexpr int kernel1_threads = 64; + dim3 block(kernel1_threads); + dim3 grid((merge_bs + kernel1_threads - 1) / kernel1_threads); + compute_scores_kernel<<>>( + Q, K, workspace, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + constexpr int work_thread_num = + Config::step2_block_size / (seq_q * max_seq_kv) * (seq_q * max_seq_kv); + dim3 grid2((merge_bs * seq_q * max_seq_kv + work_thread_num - 1) / work_thread_num); + dim3 block2(Config::step2_block_size); + apply_mask_and_softmax_kernel<<>>( + workspace, dropout_mask, dropout_scale, cu_seqlens_kv, b, head_num); + + constexpr int kernel3_block_k = 8; + constexpr int kernel3_threads = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / kernel3_block_k); + + dim3 block3(kernel3_threads); + dim3 grid3((merge_bs / process_head_per_warp + 2 - 1) / 2); + compute_output_kernel<<>>( + workspace, V, O, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); +} + +// ----- Backward kernels (with runtime batch_size, head_num) ----- + +template +__global__ void compute_grad_v_kernel(const T* attn_weights, + const T* grad_O, + T* grad_V, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id 
% (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T attn[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + attn[i] = attn_weights[cur_idx * max_seq_kv + i]; + + for (int j = 0; j < seq_kv; j++) { + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&grad_O[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * + head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + +#pragma unroll + for (int b = 0; b < block_k; b++) { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + attn[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + int grad_v_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + i * dwordx4_load_elt; + *((uint4*)&grad_V[grad_v_idx]) = store_dwordx4_tmp_var[i]; + } + } + } +} + +template +__global__ void compute_grad_attn_kernel(const T* grad_O, + const T* V, + T* grad_attn, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = 64; + constexpr int thread_block_size = 64; + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block; + int thread_id = threadIdx.x; + + for (int task = 0; task < tasks_per_block; task++) { + int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id; + int batch_idx = cur_batch_idx / (seq_q * head_num); + int seq_head_idx = cur_batch_idx % (seq_q * head_num); + int seq_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + + float results[max_seq_kv]; + T fetch_grad_O[block_k]; + T fetch_V[block_k]; + + T* grad_O_ptr = (T*)&grad_O[(batch_idx * seq_q * head_num + seq_idx * head_num + head_idx) * + head_dim]; + + const T* V_base = + &V[cu_seqlens_kv_padded[batch_idx] * head_num * head_dim + head_idx * head_dim]; + int V_stride = head_num * head_dim; + + T* grad_attn_ptr = (T*)&grad_attn[cur_batch_idx * max_seq_kv]; + + uint4 ls_dwordx4_tmp_var; + + for (int i = 0; i < seq_kv; i++) + results[i] = 0.0f; + + for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { + if constexpr (std::is_same::value) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 8]); + 
fetch_grad_O[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; + fetch_grad_O[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; + fetch_grad_O[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; + fetch_grad_O[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; + fetch_grad_O[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; + fetch_grad_O[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; + fetch_grad_O[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; + fetch_grad_O[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 8]); + fetch_V[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; + fetch_V[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; + fetch_V[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; + fetch_V[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; + fetch_V[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; + fetch_V[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; + fetch_V[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; + fetch_V[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += + static_cast(fetch_grad_O[k]) * static_cast(fetch_V[k]); + } + } else { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 4]); + fetch_grad_O[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_grad_O[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_grad_O[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_grad_O[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 4]); + fetch_V[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_V[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_V[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_V[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += fetch_grad_O[k] * fetch_V[k]; + } + } + } + for (int i = 0; i < seq_kv; i++) + grad_attn_ptr[i] = T(results[i]); + for (int i = seq_kv; i < max_seq_kv; i++) + grad_attn_ptr[i] = T(0.0f); + } +} + +template +__global__ void softmax_backward_kernel(const T* attn_weights, + const T* dropout_mask, + T* grad_attn, + float dropout_scale, + const int* cu_seqlens_kv, + int batch_size, + int head_num) +{ + const uint32_t block_id = blockIdx.x; + const uint32_t thread_id = threadIdx.x; + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int block_size = Config::step2_block_size; + constexpr int per_grad_attn_size = seq_q * max_seq_kv; + constexpr int valid_thread_range = block_size / per_grad_attn_size * per_grad_attn_size; + const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id; + const uint32_t total_elt = static_cast(batch_size) * head_num * seq_q * max_seq_kv; + bool is_tail = block_id * valid_thread_range + block_size >= total_elt; + int real_row_num = + is_tail ? 
(total_elt - block_id * valid_thread_range) / max_seq_kv + : valid_thread_range / max_seq_kv; + + if (cur_block_offset < total_elt && thread_id < valid_thread_range) { + __shared__ T tmp_grad_score[valid_thread_range]; + constexpr int row_num = valid_thread_range / max_seq_kv; + __shared__ T reduce_grad_score[row_num]; + + int global_row_idx = cur_block_offset / max_seq_kv; + int batch_idx = global_row_idx / (seq_q * head_num); + int k_idx = cur_block_offset % max_seq_kv; + + int seq_kv = (batch_idx < batch_size) + ? (cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]) + : max_seq_kv; + + T grad_attn_value = grad_attn[cur_block_offset]; + if constexpr (Config::enable_dropout_mask) + grad_attn_value = grad_attn_value * dropout_mask[cur_block_offset] * dropout_scale; + T attn_weight = attn_weights[cur_block_offset]; + T grad_score = grad_attn_value * attn_weight; + tmp_grad_score[thread_id] = grad_score; + __syncthreads(); + + if (thread_id < real_row_num) { + T sum = T(0.0f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + sum += tmp_grad_score[thread_id * max_seq_kv + i]; + reduce_grad_score[thread_id] = sum; + } + __syncthreads(); + + grad_score -= attn_weight * reduce_grad_score[thread_id / max_seq_kv]; + + if constexpr (Config::mask_type == CausalMaskType::TOP_LEFT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx > q_idx || k_idx >= seq_kv) + grad_score = T(0.0f); + } else if constexpr (Config::mask_type == CausalMaskType::BOTTOM_RIGHT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx < q_idx || k_idx >= seq_kv) + grad_score = T(0.0f); + } else { + if (k_idx >= seq_kv) + grad_score = T(0.0f); + } + grad_attn[cur_block_offset] = grad_score; + } +} + +template +__global__ void compute_grad_qk_kernel(const T* grad_scores, + const T* Q, + const T* K, + T* grad_Q, + T* grad_K, + float scale, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T grad_score_vals[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + grad_score_vals[i] = grad_scores[cur_idx * max_seq_kv + i]; + + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + 
store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + int k_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + i * dwordx4_load_elt; + load_dwordx4_tmp_var[i] = *((uint4*)&K[k_idx]); + } +#pragma unroll + for (int b = 0; b < block_k; b++) { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + grad_score_vals[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + T* grad_Q_ptr = &grad_Q[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * + head_dim + + thread_head_offset + i * dwordx4_load_elt]; + for (int b = 0; b < dwordx4_load_elt; b++) + grad_Q_ptr[b] = ((T*)&store_dwordx4_tmp_var[i])[b] * scale; + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&Q[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int b = 0; b < block_k; b++) { + T val = grad_score_vals[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] * + T(scale); + int grad_k_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + b; + grad_K[grad_k_idx] = val; + } + } + } +} + +template +void run_attn_bwd_impl(int b, + int head_num, + const T* Q, + const T* K, + const T* V, + const T* grad_O, + const T* attn_weights, + const T* dropout_mask, + float dropout_p, + float sqr_dk_scale, + T* grad_Q, + T* grad_K, + T* grad_V, + T* workspace, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + hipStream_t stream) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int warp_size = 64; + + int merge_bs = b * head_num; + float scale = sqr_dk_scale; + float dropout_scale = (dropout_p > 0.0f) ? 
(1.0f / (1.0f - dropout_p)) : 1.0f; + + dim3 block(warp_size); + constexpr int tasks_per_block_v = 16; + dim3 grid_v((b * seq_q * head_num + tasks_per_block_v - 1) / tasks_per_block_v); + compute_grad_v_kernel<<>>( + attn_weights, grad_O, grad_V, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + constexpr int tasks_per_block_attn = 16; + constexpr int process_head_per_warp = warp_size / (head_dim / 64); + dim3 grid_grad_attn((b * seq_q * head_num + tasks_per_block_attn * process_head_per_warp - 1) / + (tasks_per_block_attn * process_head_per_warp)); + compute_grad_attn_kernel<<>>( + grad_O, V, workspace, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + constexpr int work_thread_num = + Config::step2_block_size / (seq_q * max_seq_kv) * (seq_q * max_seq_kv); + dim3 grid_softmax((merge_bs * seq_q * max_seq_kv + work_thread_num - 1) / work_thread_num); + dim3 block_softmax(Config::step2_block_size); + softmax_backward_kernel<<>>( + attn_weights, dropout_mask, workspace, dropout_scale, cu_seqlens_kv, b, head_num); + + constexpr int tasks_per_block_qk = 4; + dim3 grid_qk((b * seq_q * head_num + tasks_per_block_qk - 1) / tasks_per_block_qk); + compute_grad_qk_kernel<<>>( + workspace, Q, K, grad_Q, grad_K, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); +} + +// ----- Public API: workspace size and dispatch ----- + +size_t fused_attn_smallseq_fwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype) { + (void)b; + (void)h_q; + (void)max_seqlen_kv; + (void)dtype; + return 8u; +} + +size_t fused_attn_smallseq_bwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype) { + size_t elt_size = (dtype == DType::kBFloat16 || dtype == DType::kFloat16) ? 2u : 4u; + return b * h_q * 1 * max_seqlen_kv * elt_size; +} + +template +static void dispatch_fwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* dropout_mask, + float dropout, float scale, T* O, T* workspace, const int* cu_kv, + const int* cu_kv_p, hipStream_t stream) { + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); +} + +template +static void dispatch_bwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* grad_O, + const T* attn_weights, const T* dropout_mask, float dropout, float scale, + T* grad_Q, T* grad_K, T* grad_V, T* workspace, const int* cu_kv, + const int* cu_kv_p, hipStream_t stream) { + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); +} + +void fused_attn_smallseq_fwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + bool is_training, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + void* devPtrO, + void* attn_weights_buffer, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + const void* rng_seed, + const void* rng_offset, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream) +{ + if (std::getenv("NVTE_LOG_CK_SMALLSEQ")) { + std::cerr << "[fused_attn_smallseq_fwd] ENTRY - all params: b=" << b << " h_q=" << h_q + << " h_kv=" << h_kv << " max_seqlen_kv=" << max_seqlen_kv << " d_qk=" << d_qk + << " d_v=" << d_v << " is_training=" << is_training << " attn_scale=" << attn_scale + << " dropout=" << dropout << " qkv_dtype=" + << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? 
"FP16" : "?") + << " devPtrQ=" << devPtrQ << " devPtrK=" << devPtrK << " devPtrV=" << devPtrV + << " devPtrO=" << devPtrO << " attn_weights_buffer=" << attn_weights_buffer + << " devPtrCuSeqlensKV=" << devPtrCuSeqlensKV + << " devPtrSeqOffsetsKV=" << devPtrSeqOffsetsKV << " workspace=" << workspace + << " stream=" << stream << std::endl; + } + (void)h_kv; + (void)d_qk; + (void)d_v; + (void)is_training; + (void)rng_seed; + (void)rng_offset; + NVTE_CHECK(max_seqlen_kv >= 2 && max_seqlen_kv <= 16, + "small-seq path requires 2 <= max_seqlen_kv <= 16."); + NVTE_CHECK(d_qk == 128 && d_v == 128, "small-seq path currently supports head_dim 128 only."); + + float sqr_dk_scale = attn_scale; + hipStream_t hip_stream = reinterpret_cast(stream); + + if (qkv_dtype == DType::kBFloat16) { + using T = hip_bfloat16; + const T* Q_ptr = static_cast(devPtrQ); + const T* K_ptr = static_cast(devPtrK); + const T* V_ptr = static_cast(devPtrV); + T* O_ptr = static_cast(devPtrO); + T* attn_workspace = static_cast(attn_weights_buffer); + const int* cu_kv = static_cast(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast(b); + int hi = static_cast(h_q); + + switch (max_seqlen_kv) { + case 2: dispatch_fwd<2, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 3: dispatch_fwd<3, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 4: dispatch_fwd<4, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 5: dispatch_fwd<5, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 6: dispatch_fwd<6, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 7: dispatch_fwd<7, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 8: dispatch_fwd<8, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 9: dispatch_fwd<9, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 10: dispatch_fwd<10, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 11: dispatch_fwd<11, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 12: dispatch_fwd<12, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 13: dispatch_fwd<13, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 14: dispatch_fwd<14, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 15: dispatch_fwd<15, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + case 16: dispatch_fwd<16, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, 
dropout, + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream); + break; + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); + } + } else { + NVTE_ERROR("small-seq path supports only BF16 (and optionally FP16)."); + } + + if (workspace_size) { + size_t bwd_ws = fused_attn_smallseq_bwd_workspace_size(b, h_q, max_seqlen_kv, qkv_dtype); + *workspace_size = (bwd_ws > 8u) ? bwd_ws : 8u; + } +} + +void fused_attn_smallseq_bwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + const void* devPtrO, + const void* devPtrdO, + const void* attn_weights, + void* devPtrdQ, + void* devPtrdK, + void* devPtrdV, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream) +{ + if (std::getenv("NVTE_LOG_CK_SMALLSEQ")) { + std::cerr << "[fused_attn_smallseq_bwd] ENTRY - all params: b=" << b << " h_q=" << h_q + << " h_kv=" << h_kv << " max_seqlen_kv=" << max_seqlen_kv << " d_qk=" << d_qk + << " d_v=" << d_v << " attn_scale=" << attn_scale << " dropout=" << dropout + << " qkv_dtype=" + << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?") + << " devPtrQ=" << devPtrQ << " devPtrK=" << devPtrK << " devPtrV=" << devPtrV + << " devPtrO=" << devPtrO << " devPtrdO=" << devPtrdO << " attn_weights=" << attn_weights + << " devPtrdQ=" << devPtrdQ << " devPtrdK=" << devPtrdK << " devPtrdV=" << devPtrdV + << " devPtrCuSeqlensKV=" << devPtrCuSeqlensKV + << " devPtrSeqOffsetsKV=" << devPtrSeqOffsetsKV << " workspace=" << workspace + << " stream=" << stream << std::endl; + } + (void)h_kv; + (void)d_qk; + (void)d_v; + NVTE_CHECK(max_seqlen_kv >= 2 && max_seqlen_kv <= 16, + "small-seq path requires 2 <= max_seqlen_kv <= 16."); + NVTE_CHECK(d_qk == 128 && d_v == 128, "small-seq path currently supports head_dim 128 only."); + NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace."); + + float sqr_dk_scale = attn_scale; + hipStream_t hip_stream = reinterpret_cast(stream); + + if (qkv_dtype == DType::kBFloat16) { + using T = hip_bfloat16; + const T* Q_ptr = static_cast(devPtrQ); + const T* K_ptr = static_cast(devPtrK); + const T* V_ptr = static_cast(devPtrV); + const T* O_ptr = static_cast(devPtrO); + const T* dO_ptr = static_cast(devPtrdO); + const T* attn_ptr = static_cast(attn_weights); + T* dQ_ptr = static_cast(devPtrdQ); + T* dK_ptr = static_cast(devPtrdK); + T* dV_ptr = static_cast(devPtrdV); + T* workspace_ptr = static_cast(workspace); + const int* cu_kv = static_cast(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast(b); + int hi = static_cast(h_q); + + switch (max_seqlen_kv) { + case 2: dispatch_bwd<2, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 3: dispatch_bwd<3, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 4: dispatch_bwd<4, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 5: dispatch_bwd<5, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, 
attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 6: dispatch_bwd<6, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 7: dispatch_bwd<7, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 8: dispatch_bwd<8, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 9: dispatch_bwd<9, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 10: dispatch_bwd<10, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 11: dispatch_bwd<11, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 12: dispatch_bwd<12, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 13: dispatch_bwd<13, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 14: dispatch_bwd<14, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 15: dispatch_bwd<15, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + case 16: dispatch_bwd<16, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, + dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, + cu_kv, cu_kv_p, hip_stream); break; + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); + } + } else { + NVTE_ERROR("small-seq path supports only BF16 (and optionally FP16)."); + } + + if (workspace_size) + *workspace_size = fused_attn_smallseq_bwd_workspace_size(b, h_q, max_seqlen_kv, qkv_dtype); +} + +} // namespace fused_attn_rocm +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp new file mode 100644 index 000000000..88fd6c555 --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp @@ -0,0 +1,89 @@ +/************************************************************************* + * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +/*! \file fused_attn_smallseq.hpp + * \brief Unfused small-seq (varlen) attention for ROCm: seq_q=1, max_seqlen_kv<=16, THD only. 
+ */ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ + +#include "../common.h" +#include "transformer_engine/fused_attn.h" + +namespace transformer_engine { +namespace fused_attn_rocm { + +/** Workspace size in bytes for small-seq forward path (launcher uses output_S; this is for any + * caller scratch, e.g. get_runtime_max_seqlen). Minimum 8 for atomic. */ +size_t fused_attn_smallseq_fwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype); + +/** Workspace size in bytes for small-seq backward path (grad_attn then grad_scores). */ +size_t fused_attn_smallseq_bwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype); + +/** Forward: Q,K,V -> O; attention weights written to attn_weights_buffer (same as output_S). + * attn_weights_buffer is also used as internal workspace (scores then overwritten by attn + * weights). No separate workspace required for the launcher; caller may use workspace for + * get_runtime_max_seqlen (8 bytes). */ +void fused_attn_smallseq_fwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + bool is_training, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + void* devPtrO, + void* attn_weights_buffer, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + const void* rng_seed, + const void* rng_offset, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream); + +/** Backward: dO, O, attn_weights -> dQ, dK, dV. attn_weights is the buffer from forward + * (output_S). workspace must be at least fused_attn_smallseq_bwd_workspace_size. */ +void fused_attn_smallseq_bwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + const void* devPtrO, + const void* devPtrdO, + const void* attn_weights, + void* devPtrdQ, + void* devPtrdK, + void* devPtrdV, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream); + +} // namespace fused_attn_rocm +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 45d3d8b59..91c9112cf 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -365,13 +365,36 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) elif backend == NVTE_Fused_Attn_Backend.NVTE_CK: - if config.qkv_layout.is_thd(): + if (config.qkv_layout.is_thd() and q_max_seqlen == 1 and + kv_max_seqlen <= 16): + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, + kv_max_seqlen) + softmax_dtype = dtypes.canonicalize_dtype(q_dtype) + elif config.qkv_layout.is_thd(): softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") + _small_seq_ck_used = 
( + backend == NVTE_Fused_Attn_Backend.NVTE_CK + and config.qkv_layout.is_thd() + and q_max_seqlen == 1 + and kv_max_seqlen <= 16 + ) + if os.environ.get("NVTE_LOG_CK_SMALLSEQ"): + import sys + print( + f"[CK small-seq JAX] fused_attn abstract: backend={backend!s} " + f"batch_shape={batch_shape} q_max_seqlen={q_max_seqlen} " + f"kv_max_seqlen={kv_max_seqlen} attn_heads={attn_heads} " + f"softmax_shape={softmax_shape} softmax_dtype={softmax_dtype} " + f"small_seq_path={_small_seq_ck_used}", + file=sys.stderr, + flush=True, + ) softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with From b3ef62cad591a327605a50339a343cd1f58b53ad Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Wed, 25 Feb 2026 20:09:08 +0000 Subject: [PATCH 02/21] Addressed comments --- tests/jax/test_fused_attn.py | 27 +- .../include/ck_fused_attn/ck_fused_attn.hpp | 6 + .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 28 ++ .../common/fused_attn_rocm/fused_attn_ck.cpp | 240 +++++------------- .../fused_attn_rocm/fused_attn_smallseq.cpp | 169 +++++------- .../fused_attn_rocm/fused_attn_smallseq.hpp | 9 +- .../jax/cpp_extensions/attention.py | 44 ++-- .../jax/csrc/extensions/attention.cpp | 22 ++ 8 files changed, 212 insertions(+), 333 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 114099b16..30918cb60 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -539,9 +539,8 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - # For very small sequence lengths, use 1 segment instead of 2 - # to avoid division by zero in segment size calculation - # Use the minimum of Q and KV sequence lengths to ensure both work + # For very small sequence lengths, use 1 segment to avoid max_segment_size=0 in + # generate_random_segment_ids (which would cause rng.integers(1, 1) to fail). 
min_seqlen = min(self.max_seqlen_q, self.max_seqlen_kv)
 self.num_segments_per_seq = 2 if min_seqlen > 1 else 1
 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
@@ -1230,16 +1229,16 @@ def test_jax_new_rng():
 [
 pytest.param(30720, 1, 2, 16, 16, 128, 128, jnp.bfloat16,
 id="30720-1-2-16-16-128-128-BF16"),
- pytest.param(30720, 1, 4, 16, 16, 128, 128, jnp.bfloat16,
- id="30720-1-4-16-16-128-128-BF16"),
- pytest.param(30720, 1, 6, 16, 16, 128, 128, jnp.bfloat16,
- id="30720-1-6-16-16-128-128-BF16"),
- pytest.param(30720, 1, 8, 16, 16, 128, 128, jnp.bfloat16,
- id="30720-1-8-16-16-128-128-BF16"),
- pytest.param(30720, 1, 12, 16, 16, 128, 128, jnp.bfloat16,
- id="30720-1-12-16-16-128-128-BF16"),
- pytest.param(30720, 1, 16, 16, 16, 128, 128, jnp.bfloat16,
- id="30720-1-16-16-16-128-128-BF16"),
+ # pytest.param(30720, 1, 4, 16, 16, 128, 128, jnp.bfloat16,
+ # id="30720-1-4-16-16-128-128-BF16"),
+ # pytest.param(30720, 1, 6, 16, 16, 128, 128, jnp.bfloat16,
+ # id="30720-1-6-16-16-128-128-BF16"),
+ # pytest.param(30720, 1, 8, 16, 16, 128, 128, jnp.bfloat16,
+ # id="30720-1-8-16-16-128-128-BF16"),
+ # pytest.param(30720, 1, 12, 16, 16, 128, 128, jnp.bfloat16,
+ # id="30720-1-12-16-16-128-128-BF16"),
+ # pytest.param(30720, 1, 16, 16, 16, 128, 128, jnp.bfloat16,
+ # id="30720-1-16-16-16-128-128-BF16"),
 ],
 )
 def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype):
@@ -1275,4 +1274,4 @@ def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype):
 f"d_qk={d_qk}, d_v={d_v}, dtype={dtype}"
 )
 runner.test_forward()
- runner.test_backward()
+ # runner.test_backward()
diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp
index 54ee94786..736aa0f99 100644
--- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp
+++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp
@@ -168,6 +168,12 @@ hipError_t ck_attn_varlen_bwd(
 int how_v3_bf16_cvt,
 hipStream_t stream);
 
+uint64_t get_runtime_max_seqlen(uint64_t b,
+ const void* cu_seqlen_ptr,
+ const void* cu_seqlen_padded_ptr,
+ void* workspace,
+ hipStream_t stream);
+
 }//namespace ck_fused_attn
 
 #endif // CK_FUSED_ATTN_H
diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
index 26c92ca2b..b96da1c50 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
@@ -5,6 +5,10 @@
 ************************************************************************/
 
 #include
+#include
+#include
+#include
+#include
 #include "ck_fused_attn_utils.hpp"
 #include "ck_fused_attn/ck_fused_attn.hpp"
 #include "mask.hpp"
@@ -95,6 +99,30 @@ uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const voi
 runtime_max_seqlen_ptr);
 hipMemcpyAsync(&runtime_max_seqlen, runtime_max_seqlen_ptr, sizeof(uint64_t), hipMemcpyDeviceToHost, stream);
 hipStreamSynchronize(stream);
+
+ const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG");
+ if (env_p && std::string(env_p) == "1" && cu_seqlen_ptr != nullptr && b > 0) {
+ std::vector<int32_t> host_cu(static_cast<size_t>(b) + 1);
+ hipMemcpy(host_cu.data(), cu_seqlen_ptr, (static_cast<size_t>(b) + 1) * sizeof(int32_t), hipMemcpyDeviceToHost);
+ uint64_t host_max = 0;
+ for (uint64_t i = 0; i < b; i++) {
+ int32_t len = host_cu[i + 1] - host_cu[i];
+ uint64_t u = static_cast<uint64_t>(len);
+ if (len < 0) {
+ std::cout << "[get_runtime_max_seqlen] b=" << b << " NEGATIVE len at i=" << i
+ << " cu[" << i << "]=" << host_cu[i] << " cu[" << (i+1) << "]=" << host_cu[i+1]
+ << " (kernel would produce garbage uint64)" << std::endl;
+ }
+ if (u > host_max) host_max = u;
+ }
+ const size_t n = static_cast<size_t>(b) + 1;
+ std::cout << "[get_runtime_max_seqlen] b=" << b << " shape=(" << n << ",) cu_seqlen[0..4]=";
+ for (size_t i = 0; i < std::min(n, size_t(5)); i++) std::cout << host_cu[i] << " ";
+ std::cout << " ... cu_seqlen[" << (n-5) << ".." << (n-1) << "]=";
+ for (size_t i = n - std::min(n, size_t(5)); i < n; i++) std::cout << host_cu[i] << " ";
+ std::cout << " host_max_seqlen=" << host_max << " device_returned=" << runtime_max_seqlen << std::endl;
+ }
+
 return runtime_max_seqlen;
 }
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index 7beead7b3..9cac3595f 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -9,7 +9,6 @@
 #include <numeric> // Required for std::accumulate
 #ifdef USE_FUSED_ATTN_CK
 #include <ck_fused_attn/ck_fused_attn.hpp>
-#include "../../ck_fused_attn/src/ck_fused_attn_utils.hpp"
 #include "fused_attn_smallseq.hpp"
 #endif // USE_FUSED_ATTN_CK
 #include "../util/cuda_runtime.h"
 #include "../util/system.h"
@@ -616,6 +615,34 @@ void fused_attn_ck_fwd_impl(
 // denote the next available section of workspace from upstream
 void* workspace_next = workspace;
 
+ if (is_ragged) {
+ void* max_seqlen_workspace = workspace;
+
+ size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+ static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr,
+ max_seqlen_workspace, reinterpret_cast<hipStream_t>(stream)));
+ size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+ static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr,
+ max_seqlen_workspace, reinterpret_cast<hipStream_t>(stream)));
+
+ if (nvte_log_ck_config) {
+ std::cout << std::endl << "[CK small-seq] fused_attn_ck_fwd_impl: is_ragged=1 b=" << b
+ << " runtime_max_seqlen_q=" << runtime_max_seqlen_q
+ << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv << std::endl;
+ }
+
+ if (runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
+ fused_attn_rocm::fused_attn_smallseq_fwd(
+ b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
+ is_training, scaling_factor, dropout_probability,
+ devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux,
+ devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
+ devPtrDropoutSeed, devPtrDropoutOffset,
+ dtype, workspace, workspace_size, stream);
+ return;
+ }
+ }
+
 std::array<uint64_t, 4> q_stride;
 std::array<uint64_t, 4> k_stride;
 std::array<uint64_t, 4> v_stride;
@@ -918,6 +945,32 @@ void fused_attn_ck_bwd_impl(
 // denote the next available section of workspace from upstream
 void* workspace_next = workspace;
 
+ if (is_ragged) {
+ void* max_seqlen_workspace_bwd = workspace;
+ // When s_q == 1 use 1 for runtime_max_seqlen_q (Q cu_seqlens layout may differ in JAX THD).
+ size_t runtime_max_seqlen_q_bwd = (s_q == 1) ? 1u : static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+ static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr,
+ max_seqlen_workspace_bwd, reinterpret_cast<hipStream_t>(stream)));
+ size_t runtime_max_seqlen_kv_bwd = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+ static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr,
+ max_seqlen_workspace_bwd, reinterpret_cast<hipStream_t>(stream)));
+ if (nvte_log_ck_config) {
+ std::cout << std::endl << "[CK small-seq] fused_attn_ck_bwd_impl: is_ragged=1 runtime_max_seqlen_q="
+ << runtime_max_seqlen_q_bwd << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv_bwd << std::endl;
+ }
+ if (runtime_max_seqlen_q_bwd == 1 && runtime_max_seqlen_kv_bwd >= 2 && runtime_max_seqlen_kv_bwd <= 16) {
+
+ fused_attn_rocm::fused_attn_smallseq_bwd(
+ b, h, hg, runtime_max_seqlen_kv_bwd, d_qk, d_v,
+ scaling_factor, dropout_probability,
+ devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux,
+ devPtrdQ, devPtrdK, devPtrdV,
+ devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
+ dtype, workspace, workspace_size, stream);
+ return;
+ }
+ }
+
 std::array<uint64_t, 4> q_stride;
 std::array<uint64_t, 4> k_stride;
 std::array<uint64_t, 4> v_stride;
@@ -1831,75 +1884,18 @@ void fused_attn_ck_fwd(
 size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_kv/d_qk;
 
 bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
- size_t runtime_max_seqlen_kv = max_seqlen_kv;
- bool use_small_seq = false;
- const bool log_smallseq = (std::getenv("NVTE_LOG_CK_SMALLSEQ") != nullptr);
- if (log_smallseq) {
- std::cerr << "[CK small-seq] fused_attn_ck_fwd ENTRY: b=" << b << " h_q=" << h_q
- << " max_seqlen_q=" << max_seqlen_q << " max_seqlen_kv=" << max_seqlen_kv
- << " is_ragged=" << is_ragged << " Aux_CTX_size=" << Aux_CTX_Tensors->size << std::endl;
- }
-#ifdef USE_FUSED_ATTN_CK
- // THD can pass segment-level cu_seqlens (length b). Varlen kernel expects sequence-level batch;
- // when max_seqlen_q==1, max_tokens_q == number of sequences → use as batch in varlen path.
- if (is_ragged && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || bias_type == NVTE_Bias_Type::NVTE_ALIBI)) {
- const size_t b_varlen = max_tokens_q;
- if (Aux_CTX_Tensors->size == 0) {
- runtime_max_seqlen_kv = max_seqlen_kv;
- use_small_seq = (max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16);
- if (log_smallseq) {
- std::cerr << "[CK small-seq] FWD shape query (size==0): skip get_runtime_max_seqlen, "
- << "use host max_seqlen_kv=" << max_seqlen_kv << " use_small_seq=" << use_small_seq
- << std::endl;
- }
- } else {
- if (log_smallseq) {
- std::cerr << "[CK small-seq] FWD THD branch: calling get_runtime_max_seqlen (b_varlen=" << b_varlen
- << " devPtrCuSeqlensKV=" << devPtrCuSeqlensKV
- << " devPtrSeqOffsetsKV=" << devPtrSeqOffsetsKV << ")" << std::endl;
- }
- void* max_seqlen_workspace = workspace->data.dptr;
- bool need_free = false;
- if (max_seqlen_workspace == nullptr) {
- NVTE_CHECK_CUDA(hipMalloc(&max_seqlen_workspace, sizeof(uint64_t)));
- need_free = true;
- }
- runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
- static_cast<uint64_t>(b_varlen), devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
- max_seqlen_workspace, reinterpret_cast<hipStream_t>(stream)));
- if (need_free) {
- NVTE_CHECK_CUDA(hipFree(max_seqlen_workspace));
- }
- use_small_seq = (max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16);
- if (log_smallseq) {
- std::cerr << "[CK small-seq FWD] get_runtime_max_seqlen returned " << runtime_max_seqlen_kv
- << " use_small_seq=" << use_small_seq << std::endl;
- }
- if (use_small_seq && log_smallseq) {
- std::cerr << "[CK small-seq FWD] Dispatch: using specialized varlen kernel. "
- << "b_varlen=" << b_varlen << " h_q=" << h_q << " h_kv=" << h_kv
- << " max_seqlen_q=" << max_seqlen_q << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv
- << " d_qk=" << d_qk << " d_v=" << d_v << " is_training=" << is_training
- << " attn_scale=" << attn_scale << " dropout=" << dropout << std::endl;
- }
- }
- }
-#endif
+
 if (Aux_CTX_Tensors->size == 0) {
 if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
 Aux_CTX_Tensors->size = 3;
 Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
 output_S->data.dptr = nullptr;
- if (use_small_seq) {
- output_S->data.shape = {max_tokens_q, h_q, 1, runtime_max_seqlen_kv};
- output_S->data.dtype = QKV_type;
- } else if(is_ragged){
+ if(is_ragged){
 output_S->data.shape = {max_tokens_q, h_q, 1};
- output_S->data.dtype = DType::kFloat32;
 }else{
 output_S->data.shape = {b, h_q, max_seqlen_q, 1};
- output_S->data.dtype = DType::kFloat32;
 }
+ output_S->data.dtype = DType::kFloat32;
 Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
 output_rng_state->data.dptr = nullptr;
 output_rng_state->data.shape = {2};
 output_rng_state->data.dtype = DType::kInt64;
@@ -1912,33 +1908,17 @@
 Aux_CTX_Tensors->size = 2;
 Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
 output_S->data.dptr = nullptr;
- if (use_small_seq) {
- output_S->data.shape = {max_tokens_q, h_q, 1, runtime_max_seqlen_kv};
- output_S->data.dtype = QKV_type;
- } else if(is_ragged){
+ if(is_ragged){
 output_S->data.shape = {max_tokens_q, h_q, 1};
- output_S->data.dtype = DType::kFloat32;
 }else{
 output_S->data.shape = {b, h_q, max_seqlen_q, 1};
- output_S->data.dtype = DType::kFloat32;
 }
+ output_S->data.dtype = DType::kFloat32;
 Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
 output_rng_state->data.dptr = nullptr;
 output_rng_state->data.shape = {2};
 output_rng_state->data.dtype = DType::kInt64;
 }
- if (use_small_seq) {
- if (log_smallseq) {
- std::cerr << "[CK small-seq FWD] Shape query: output_S shape={max_tokens_q,h_q,1,runtime_max_seqlen_kv}="
- << "{" << max_tokens_q << "," << h_q << ",1," << runtime_max_seqlen_kv << "}, dtype=QKV_type"
- << std::endl;
- }
- size_t small_seq_ws = fused_attn_rocm::fused_attn_smallseq_bwd_workspace_size(
- max_tokens_q, h_q, runtime_max_seqlen_kv, QKV_type);
- workspace->data.shape = {small_seq_ws > 8u ? small_seq_ws : 8u};
- workspace->data.dtype = DType::kByte;
- return;
- }
 } else if (Aux_CTX_Tensors->size == 2) {
 Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
 devPtrS = output_S->data.dptr;
@@ -1960,35 +1940,6 @@
 attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
 attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
 
- if (use_small_seq && (Aux_CTX_Tensors->size == 2 || Aux_CTX_Tensors->size == 3)) {
- if (log_smallseq) {
- std::cerr << "[CK small-seq FWD] Running specialized kernel: b_varlen=" << max_tokens_q << " h_q=" << h_q
- << " h_kv=" << h_kv << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv
- << " d_qk=" << d_qk << " d_v=" << d_v << " is_training=" << is_training
- << " attn_scale=" << attn_scale << " dropout=" << dropout
- << " Aux_CTX_Tensors->size=" << Aux_CTX_Tensors->size << std::endl;
- }
- fused_attn_rocm::fused_attn_smallseq_fwd(
- max_tokens_q, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v,
- is_training, attn_scale, dropout,
- devPtrQ, devPtrK, devPtrV, devPtrO, devPtrS,
- devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
- rng_state->data.dptr,
- reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1),
- QKV_type, workspace->data.dptr, &workspace_size, stream);
- if (workspace_size > 0) {
- if (workspace->data.dptr == nullptr) {
- workspace->data.shape = {workspace_size};
- workspace->data.dtype = DType::kByte;
- return;
- }
- } else {
- workspace->data.shape = {1};
- workspace->data.dtype = DType::kByte;
- }
- return;
- }
-
 fused_attn_ck_fwd_impl(
 b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v,
 bias_b, bias_h, max_tokens_q, max_tokens_kv,
@@ -2072,79 +2023,8 @@ void fused_attn_ck_bwd(
 void *devPtrSeqOffsetsKV = input_cu_seqlens_kv_padded->data.dptr;
 size_t workspace_size = 0;
- size_t max_tokens_q_bwd = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>()) / h_q / d_qk;
-
- bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
- size_t runtime_max_seqlen_kv_bwd = max_seqlen_kv;
- bool use_small_seq_bwd = false;
- const bool log_smallseq_bwd = (std::getenv("NVTE_LOG_CK_SMALLSEQ") != nullptr);
- if (log_smallseq_bwd) {
- std::cerr << "[CK small-seq] fused_attn_ck_bwd ENTRY: b=" << b << " h_q=" << h_q
- << " max_seqlen_q=" << max_seqlen_q << " max_seqlen_kv=" << max_seqlen_kv
- << " is_ragged=" << is_ragged << std::endl;
- }
- // Varlen path uses sequence count (max_tokens_q) as batch; see comment in fused_attn_ck_fwd.
- if (is_ragged && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || bias_type == NVTE_Bias_Type::NVTE_ALIBI)) {
- const size_t b_varlen = max_tokens_q_bwd;
- if (workspace->data.dptr == nullptr) {
- runtime_max_seqlen_kv_bwd = max_seqlen_kv;
- use_small_seq_bwd = (max_seqlen_q == 1 && runtime_max_seqlen_kv_bwd >= 2 && runtime_max_seqlen_kv_bwd <= 16);
- if (log_smallseq_bwd) {
- std::cerr << "[CK small-seq] BWD workspace query (workspace==null): skip get_runtime_max_seqlen, "
- << "use host max_seqlen_kv=" << max_seqlen_kv << " use_small_seq_bwd=" << use_small_seq_bwd
- << std::endl;
- }
- } else {
- if (log_smallseq_bwd) {
- std::cerr << "[CK small-seq] BWD THD branch: calling get_runtime_max_seqlen (b_varlen=" << b_varlen << ")" << std::endl;
- }
- void* max_seqlen_workspace_bwd = workspace->data.dptr;
- runtime_max_seqlen_kv_bwd = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
- static_cast<uint64_t>(b_varlen), devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
- max_seqlen_workspace_bwd, reinterpret_cast<hipStream_t>(stream)));
- use_small_seq_bwd = (max_seqlen_q == 1 && runtime_max_seqlen_kv_bwd >= 2 && runtime_max_seqlen_kv_bwd <= 16);
- if (log_smallseq_bwd) {
- std::cerr << "[CK small-seq BWD] get_runtime_max_seqlen returned " << runtime_max_seqlen_kv_bwd
- << " use_small_seq_bwd=" << use_small_seq_bwd << std::endl;
- }
- }
- if (use_small_seq_bwd && log_smallseq_bwd) {
- std::cerr << "[CK small-seq BWD] Dispatch: using specialized varlen kernel. "
- << "b_varlen=" << max_tokens_q_bwd << " h_q=" << h_q << " h_kv=" << h_kv
- << " max_seqlen_q=" << max_seqlen_q << " runtime_max_seqlen_kv_bwd=" << runtime_max_seqlen_kv_bwd
- << " d_qk=" << d_qk << " d_v=" << d_v
- << " attn_scale=" << attn_scale << " dropout=" << dropout << std::endl;
- }
- }
- if (use_small_seq_bwd) {
- size_t small_seq_bwd_workspace = fused_attn_rocm::fused_attn_smallseq_bwd_workspace_size(
- max_tokens_q_bwd, h_q, runtime_max_seqlen_kv_bwd, QKV_type);
- if (workspace->data.dptr == nullptr) {
- if (log_smallseq_bwd) {
- std::cerr << "[CK small-seq BWD] Workspace query: workspace_size=" << small_seq_bwd_workspace << std::endl;
- }
- workspace->data.shape = {small_seq_bwd_workspace};
- workspace->data.dtype = DType::kByte;
- return;
- }
- if (log_smallseq_bwd) {
- std::cerr << "[CK small-seq BWD] Running specialized kernel: b_varlen=" << max_tokens_q_bwd << " h_q=" << h_q
- << " h_kv=" << h_kv << " runtime_max_seqlen_kv_bwd=" << runtime_max_seqlen_kv_bwd
- << " d_qk=" << d_qk << " d_v=" << d_v
- << " attn_scale=" << attn_scale << " dropout=" << dropout << std::endl;
- }
- fused_attn_rocm::fused_attn_smallseq_bwd(
- max_tokens_q_bwd, h_q, h_kv, runtime_max_seqlen_kv_bwd, d_qk, d_v,
- attn_scale, dropout,
- devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxStats,
- devPtrdQ, devPtrdK, devPtrdV,
- devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
- QKV_type, workspace->data.dptr, &workspace_size, stream);
- workspace->data.shape = {workspace_size > 0 ? workspace_size : 1};
- workspace->data.dtype = DType::kByte;
- return;
- }
+ bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
 bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
 attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
 attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp
index b36365fb0..04ec1dea8 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp
@@ -21,6 +21,21 @@
 #include "fused_attn_smallseq.hpp"
 #include "utils.h"
 
+// Macros to avoid repeating dispatch switch cases for max_seqlen_kv in [2, 16].
+// T, bi, hi and the pointer/scale args must be in scope where these are used.
+#define SMALLSEQ_DISPATCH_FWD_CASE(N) \
+ case N: \
+ dispatch_fwd<N, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, \
+ sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, \
+ hip_stream); \
+ break;
+#define SMALLSEQ_DISPATCH_BWD_CASE(N) \
+ case N: \
+ dispatch_bwd<N, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, \
+ dropout_mask, dropout, sqr_dk_scale, dQ_ptr, dK_ptr, \
+ dV_ptr, workspace_ptr, cu_kv, cu_kv_p, hip_stream); \
+ break;
+
 namespace transformer_engine {
 namespace fused_attn_rocm {
@@ -40,6 +55,12 @@ struct SmallSeqConfig {
 static constexpr CausalMaskType mask_type = MASK_TYPE;
 };
 
+/* MAX_SEQ_KV and HEAD_DIM are compile-time so kernels can use fixed stack arrays
+ * (e.g. float results[max_seq_kv], T attn[max_seq_kv]) and constexpr grid/block
+ * sizes. This matches varlen_attn/attn_fwd.cpp (FmhaKernelConfig<..., MAX_SEQ_KV, HEAD_DIM>)
+ * and INTEGRATION_TASK.md: seq_q==1, max_seq_kv<=16; head_dim=128 is the only
+ * value tested in varlen_attn (main() uses TestRunner<2,16>::run<..., 128, ...>). */
+
 // ----- Forward kernels (with runtime batch_size, head_num) -----
 
 template
@@ -763,25 +784,12 @@ void run_attn_bwd_impl(int b,
 workspace, Q, K, grad_Q, grad_K, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num);
 }
 
-// ----- Public API: workspace size and dispatch -----
-
-size_t fused_attn_smallseq_fwd_workspace_size(size_t b,
- size_t h_q,
- size_t max_seqlen_kv,
- DType dtype) {
- (void)b;
- (void)h_q;
- (void)max_seqlen_kv;
- (void)dtype;
- return 8u;
-}
-
 size_t fused_attn_smallseq_bwd_workspace_size(size_t b,
 size_t h_q,
 size_t max_seqlen_kv,
 DType dtype) {
- size_t elt_size = (dtype == DType::kBFloat16 || dtype == DType::kFloat16) ? 2u : 4u;
- return b * h_q * 1 * max_seqlen_kv * elt_size;
+ constexpr size_t elt_size = 2u; // BF16 and FP16 are 2 bytes
+ return b * h_q * 1 * std::min(max_seqlen_kv, size_t(16)) * elt_size;
 }
 
 template
@@ -825,8 +833,8 @@ void fused_attn_smallseq_fwd(size_t b,
 size_t* workspace_size,
 cudaStream_t stream)
 {
- if (std::getenv("NVTE_LOG_CK_SMALLSEQ")) {
- std::cerr << "[fused_attn_smallseq_fwd] ENTRY - all params: b=" << b << " h_q=" << h_q
+ if (std::getenv("NVTE_LOG_CK_CONFIG")) {
+ std::cout << "[fused_attn_smallseq_fwd] ENTRY - all params: b=" << b << " h_q=" << h_q
 << " h_kv=" << h_kv << " max_seqlen_kv=" << max_seqlen_kv << " d_qk=" << d_qk
 << " d_v=" << d_v << " is_training=" << is_training << " attn_scale=" << attn_scale
 << " dropout=" << dropout << " qkv_dtype="
@@ -843,9 +851,6 @@
 (void)is_training;
 (void)rng_seed;
 (void)rng_offset;
- NVTE_CHECK(max_seqlen_kv >= 2 && max_seqlen_kv <= 16,
- "small-seq path requires 2 <= max_seqlen_kv <= 16.");
- NVTE_CHECK(d_qk == 128 && d_v == 128, "small-seq path currently supports head_dim 128 only.");
 
 float sqr_dk_scale = attn_scale;
 hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
@@ -864,51 +869,21 @@
 int hi = static_cast<int>(h_q);
 
 switch (max_seqlen_kv) {
- case 2: dispatch_fwd<2, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 3: dispatch_fwd<3, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 4: dispatch_fwd<4, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 5: dispatch_fwd<5, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 6: dispatch_fwd<6, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 7: dispatch_fwd<7, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 8: dispatch_fwd<8, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 9: dispatch_fwd<9, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 10: dispatch_fwd<10, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 11: dispatch_fwd<11, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 12: dispatch_fwd<12, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 13: dispatch_fwd<13, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 14: dispatch_fwd<14, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 15: dispatch_fwd<15, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
- case 16: dispatch_fwd<16, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout,
- sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, hip_stream);
- break;
+ SMALLSEQ_DISPATCH_FWD_CASE(2)
+ SMALLSEQ_DISPATCH_FWD_CASE(3)
+ SMALLSEQ_DISPATCH_FWD_CASE(4)
+ SMALLSEQ_DISPATCH_FWD_CASE(5)
+ SMALLSEQ_DISPATCH_FWD_CASE(6)
+ SMALLSEQ_DISPATCH_FWD_CASE(7)
+ SMALLSEQ_DISPATCH_FWD_CASE(8)
+ SMALLSEQ_DISPATCH_FWD_CASE(9)
+ SMALLSEQ_DISPATCH_FWD_CASE(10)
+ SMALLSEQ_DISPATCH_FWD_CASE(11)
+ SMALLSEQ_DISPATCH_FWD_CASE(12)
+ SMALLSEQ_DISPATCH_FWD_CASE(13)
+ SMALLSEQ_DISPATCH_FWD_CASE(14)
+ SMALLSEQ_DISPATCH_FWD_CASE(15)
+ SMALLSEQ_DISPATCH_FWD_CASE(16)
 default:
 NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16.");
 }
@@ -946,8 +921,8 @@ void fused_attn_smallseq_bwd(size_t b,
 size_t* workspace_size,
 cudaStream_t stream)
 {
- if (std::getenv("NVTE_LOG_CK_SMALLSEQ")) {
- std::cerr << "[fused_attn_smallseq_bwd] ENTRY - all params: b=" << b << " h_q=" << h_q
+ if (std::getenv("NVTE_LOG_CK_CONFIG")) {
+ std::cout << "[fused_attn_smallseq_bwd] ENTRY - all params: b=" << b << " h_q=" << h_q
 << " h_kv=" << h_kv << " max_seqlen_kv=" << max_seqlen_kv << " d_qk=" << d_qk
 << " d_v=" << d_v << " attn_scale=" << attn_scale << " dropout=" << dropout
 << " qkv_dtype="
@@ -989,51 +964,21 @@
 int hi = static_cast<int>(h_q);
 
 switch (max_seqlen_kv) {
- case 2: dispatch_bwd<2, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 3: dispatch_bwd<3, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 4: dispatch_bwd<4, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 5: dispatch_bwd<5, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 6: dispatch_bwd<6, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 7: dispatch_bwd<7, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 8: dispatch_bwd<8, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 9: dispatch_bwd<9, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 10: dispatch_bwd<10, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 11: dispatch_bwd<11, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 12: dispatch_bwd<12, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask,
- dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr,
- cu_kv, cu_kv_p, hip_stream); break;
- case 13: dispatch_bwd<13, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr,
attn_ptr, dropout_mask, - dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, - cu_kv, cu_kv_p, hip_stream); break; - case 14: dispatch_bwd<14, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, - dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, - cu_kv, cu_kv_p, hip_stream); break; - case 15: dispatch_bwd<15, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, - dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, - cu_kv, cu_kv_p, hip_stream); break; - case 16: dispatch_bwd<16, T>(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, dropout_mask, - dropout, sqr_dk_scale, dQ_ptr, dK_ptr, dV_ptr, workspace_ptr, - cu_kv, cu_kv_p, hip_stream); break; + SMALLSEQ_DISPATCH_BWD_CASE(2) + SMALLSEQ_DISPATCH_BWD_CASE(3) + SMALLSEQ_DISPATCH_BWD_CASE(4) + SMALLSEQ_DISPATCH_BWD_CASE(5) + SMALLSEQ_DISPATCH_BWD_CASE(6) + SMALLSEQ_DISPATCH_BWD_CASE(7) + SMALLSEQ_DISPATCH_BWD_CASE(8) + SMALLSEQ_DISPATCH_BWD_CASE(9) + SMALLSEQ_DISPATCH_BWD_CASE(10) + SMALLSEQ_DISPATCH_BWD_CASE(11) + SMALLSEQ_DISPATCH_BWD_CASE(12) + SMALLSEQ_DISPATCH_BWD_CASE(13) + SMALLSEQ_DISPATCH_BWD_CASE(14) + SMALLSEQ_DISPATCH_BWD_CASE(15) + SMALLSEQ_DISPATCH_BWD_CASE(16) default: NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); } diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp index 88fd6c555..f21bfaa0c 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp @@ -17,14 +17,7 @@ namespace transformer_engine { namespace fused_attn_rocm { -/** Workspace size in bytes for small-seq forward path (launcher uses output_S; this is for any - * caller scratch, e.g. get_runtime_max_seqlen). Minimum 8 for atomic. */ -size_t fused_attn_smallseq_fwd_workspace_size(size_t b, - size_t h_q, - size_t max_seqlen_kv, - DType dtype); - -/** Workspace size in bytes for small-seq backward path (grad_attn then grad_scores). 
*/ +/** Workspace size in bytes for small-seq backward path */ size_t fused_attn_smallseq_bwd_workspace_size(size_t b, size_t h_q, size_t max_seqlen_kv, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 91c9112cf..8a4a84de5 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -365,36 +365,42 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) elif backend == NVTE_Fused_Attn_Backend.NVTE_CK: - if (config.qkv_layout.is_thd() and q_max_seqlen == 1 and - kv_max_seqlen <= 16): - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, - kv_max_seqlen) - softmax_dtype = dtypes.canonicalize_dtype(q_dtype) - elif config.qkv_layout.is_thd(): - softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + if config.qkv_layout.is_thd(): + batch_size = reduce(operator.mul, batch_shape) + old_ck_softmax_aux_size = ( + batch_size * attn_heads * q_max_seqlen * jnp.dtype(jnp.float32).itemsize + ) + possible_special_cross_attn_softmax_aux_size = ( + batch_size * attn_heads * q_max_seqlen + * min(kv_max_seqlen, 16) * 2 + ) # 2 bytes for bf16/fp16 + if (old_ck_softmax_aux_size + >= possible_special_cross_attn_softmax_aux_size): + softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + else: + softmax_shape = ( + *batch_shape, + attn_heads, + q_max_seqlen, + min(kv_max_seqlen, 16), + ) + softmax_dtype = dtypes.canonicalize_dtype(q_dtype) else: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") - _small_seq_ck_used = ( - backend == NVTE_Fused_Attn_Backend.NVTE_CK - and config.qkv_layout.is_thd() - and q_max_seqlen == 1 - and kv_max_seqlen <= 16 - ) - if os.environ.get("NVTE_LOG_CK_SMALLSEQ"): + + if os.environ.get("NVTE_LOG_CK_CONFIG"): import sys - print( + msg = ( f"[CK small-seq JAX] fused_attn abstract: backend={backend!s} " f"batch_shape={batch_shape} q_max_seqlen={q_max_seqlen} " f"kv_max_seqlen={kv_max_seqlen} attn_heads={attn_heads} " f"softmax_shape={softmax_shape} softmax_dtype={softmax_dtype} " - f"small_seq_path={_small_seq_ck_used}", - file=sys.stderr, - flush=True, ) + print(msg, file=sys.stderr, flush=True) softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 342953746..544df56fa 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -223,6 +223,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \ size_t num_segments = input_batch; \ + std::cerr << "[FUSED_ATTN_IMPL_COMMON_BLOCK] input_batch=" << input_batch << std::endl; \ if (is_ragged) { \ auto cudnn_runtime_version = cudnnGetVersion(); \ num_segments = input_batch * max_segments_per_seq; \ @@ -509,6 +510,27 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( 
nvte_tensor_pack_destroy(&aux_input_tensors); auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); + size_t workspace_elems = product(work_shape); + size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype()); + size_t workspace_bytes = workspace_elems * elt_size; + size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for small-seq (bf16/fp16) + + if (is_ragged && workspace_bytes < fused_small_seq_workspace) { + size_t min_elems = (fused_small_seq_workspace + elt_size - 1) / elt_size; + work_shape = std::vector{min_elems}; + workspace_elems = min_elems; + workspace_bytes = workspace_elems * elt_size; + } + + std::cerr << "[GetFusedAttnBackwardWorkspaceSizes] input_batch=" << input_batch + << " is_ragged=" << is_ragged << " workspace_shape=("; + for (size_t i = 0; i < work_shape.size(); ++i) { + std::cerr << (i ? "," : "") << work_shape[i]; + } + std::cerr << ") workspace_elems=" << workspace_elems << " workspace_bytes=" << workspace_bytes + << " b*h*16*2=" << fused_small_seq_workspace + << " (workspace_bytes>=b*h*16*2)=" << (workspace_bytes >= fused_small_seq_workspace) + << std::endl; return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } From db685c4fb0b8a151244fac1e4cea55af93306d0e Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 27 Feb 2026 04:28:35 +0000 Subject: [PATCH 03/21] Addressed reviews --- tests/jax/test_fused_attn.py | 102 +++++++++++------- .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 29 +---- .../common/fused_attn_rocm/fused_attn_ck.cpp | 47 ++++---- .../fused_attn_rocm/fused_attn_smallseq.cpp | 48 +++++---- ...ttn_smallseq.hpp => fused_attn_smallseq.h} | 2 +- .../jax/cpp_extensions/attention.py | 11 +- .../jax/csrc/extensions/attention.cpp | 18 ++-- 7 files changed, 127 insertions(+), 130 deletions(-) rename transformer_engine/common/fused_attn_rocm/{fused_attn_smallseq.hpp => fused_attn_smallseq.h} (99%) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 30918cb60..c598ffdaa 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -329,7 +329,7 @@ class FusedAttnRunner: # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): if self.qkv_layout.is_thd(): - if 90400 <= get_cudnn_version() < 90500: + if 90400 <= get_cudnn_version() < 90500 or self.max_seqlen_q == 1: return self.num_segments_per_seq else: # +1 for testing runtime_segments < max_segments @@ -539,30 +539,58 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - # For very small sequence lengths, use 1 segment to avoid max_segment_size=0 in - # generate_random_segment_ids (which would cause rng.integers(1, 1) to fail). 
- min_seqlen = min(self.max_seqlen_q, self.max_seqlen_kv) - self.num_segments_per_seq = 2 if min_seqlen > 1 else 1 - self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( - self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 - ) - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) - # TODO(rewang): record only self attention and find the reason of cross attention - if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: - self.segment_ids_kv = self.segment_ids_q - self.segment_pos_kv = self.segment_pos_q - self.pad_kv = self.pad_q - else: - # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support + if self.max_seqlen_q == 1: + self.num_segments_per_seq = 1 + # Q: deterministic — one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size] + self.segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + self.segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + self.pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + self.seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32) + self.offsets_q = jnp.concatenate( + [ + jnp.arange(self.batch_size, dtype=jnp.int32)[:, None], + jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), + ], + axis=1, + ) + + # KV: one segment per batch (num_segments_per_seq=1) to match smallseq kernel + # expectations (batch_size == max_tokens_q, cu_seqlens of size batch_size+1). min_segment_len = None if self.window_size is None else self.seqlens_q - self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( - self.batch_size, - self.max_seqlen_kv, - self.num_segments_per_seq, - seed=2024, - min_segment_len=min_segment_len, + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = ( + generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + self.num_segments_per_seq, # 1 for s_q=1 path + seed=2024, + min_segment_len=min_segment_len, + ) ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets( + self.segment_ids_kv + ) + else: + self.num_segments_per_seq = 2 + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 + ) + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + # TODO(rewang): record only self attention and find the reason of cross attention + if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: + self.segment_ids_kv = self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q + else: + # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support + min_segment_len = None if self.window_size is None else self.seqlens_q + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + self.num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 self.segment_ids_q, self.pad_q = gen_valid( @@ -1229,16 +1257,16 @@ def test_jax_new_rng(): [ pytest.param(30720, 1, 2, 16, 16, 128, 128, jnp.bfloat16, id="30720-1-2-16-16-128-128-BF16"), - # pytest.param(30720, 1, 4, 16, 16, 128, 128, jnp.bfloat16, - # id="30720-1-4-16-16-128-128-BF16"), - # 
pytest.param(30720, 1, 6, 16, 16, 128, 128, jnp.bfloat16,
- # id="30720-1-6-16-16-128-128-BF16"),
- # pytest.param(30720, 1, 8, 16, 16, 128, 128, jnp.bfloat16,
- # id="30720-1-8-16-16-128-128-BF16"),
- # pytest.param(30720, 1, 12, 16, 16, 128, 128, jnp.bfloat16,
- # id="30720-1-12-16-16-128-128-BF16"),
- # pytest.param(30720, 1, 16, 16, 16, 128, 128, jnp.bfloat16,
- # id="30720-1-16-16-16-128-128-BF16"),
+ pytest.param(30720, 1, 4, 16, 16, 128, 128, jnp.bfloat16,
+ id="30720-1-4-16-16-128-128-BF16"),
+ pytest.param(30720, 1, 6, 16, 16, 128, 128, jnp.bfloat16,
+ id="30720-1-6-16-16-128-128-BF16"),
+ pytest.param(30720, 1, 8, 16, 16, 128, 128, jnp.bfloat16,
+ id="30720-1-8-16-16-128-128-BF16"),
+ pytest.param(30720, 1, 12, 16, 16, 128, 128, jnp.bfloat16,
+ id="30720-1-12-16-16-128-128-BF16"),
+ pytest.param(30720, 1, 16, 16, 16, 128, 128, jnp.bfloat16,
+ id="30720-1-16-16-16-128-128-BF16"),
 ],
 )
 def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype):
@@ -1267,11 +1295,5 @@ def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype):
 )
 runner._setup_inputs()
 expected_backend = NVTE_Fused_Attn_Backend.NVTE_CK
- if runner.backend != expected_backend:
- pytest.skip(
- f"Backend selection failed: expected {expected_backend}, got {runner.backend}. "
- f"Config: b={b}, s_q={s_q}, s_kv={s_kv}, h_q={h_q}, h_kv={h_kv}, "
- f"d_qk={d_qk}, d_v={d_v}, dtype={dtype}"
- )
 runner.test_forward()
- # runner.test_backward()
+ runner.test_backward()
diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
index b96da1c50..6bb9f96e3 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
@@ -5,10 +5,7 @@
 ************************************************************************/
 
 #include
-#include
-#include
-#include
-#include
+
 #include "ck_fused_attn_utils.hpp"
 #include "ck_fused_attn/ck_fused_attn.hpp"
 #include "mask.hpp"
@@ -99,30 +96,6 @@ uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const voi
 runtime_max_seqlen_ptr);
 hipMemcpyAsync(&runtime_max_seqlen, runtime_max_seqlen_ptr, sizeof(uint64_t), hipMemcpyDeviceToHost, stream);
 hipStreamSynchronize(stream);
-
- const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG");
- if (env_p && std::string(env_p) == "1" && cu_seqlen_ptr != nullptr && b > 0) {
- std::vector<int32_t> host_cu(static_cast<size_t>(b) + 1);
- hipMemcpy(host_cu.data(), cu_seqlen_ptr, (static_cast<size_t>(b) + 1) * sizeof(int32_t), hipMemcpyDeviceToHost);
- uint64_t host_max = 0;
- for (uint64_t i = 0; i < b; i++) {
- int32_t len = host_cu[i + 1] - host_cu[i];
- uint64_t u = static_cast<uint64_t>(len);
- if (len < 0) {
- std::cout << "[get_runtime_max_seqlen] b=" << b << " NEGATIVE len at i=" << i
- << " cu[" << i << "]=" << host_cu[i] << " cu[" << (i+1) << "]=" << host_cu[i+1]
- << " (kernel would produce garbage uint64)" << std::endl;
- }
- if (u > host_max) host_max = u;
- }
- const size_t n = static_cast<size_t>(b) + 1;
- std::cout << "[get_runtime_max_seqlen] b=" << b << " shape=(" << n << ",) cu_seqlen[0..4]=";
- for (size_t i = 0; i < std::min(n, size_t(5)); i++) std::cout << host_cu[i] << " ";
- std::cout << " ... cu_seqlen[" << (n-5) << ".." << (n-1) << "]=";
- for (size_t i = n - std::min(n, size_t(5)); i < n; i++) std::cout << host_cu[i] << " ";
- std::cout << " host_max_seqlen=" << host_max << " device_returned=" << runtime_max_seqlen << std::endl;
- }
-
 return runtime_max_seqlen;
 }
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index 9cac3595f..c5de6bda3 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -9,7 +9,7 @@
 #include <numeric> // Required for std::accumulate
 #ifdef USE_FUSED_ATTN_CK
 #include <ck_fused_attn/ck_fused_attn.hpp>
-#include "fused_attn_smallseq.hpp"
+#include "fused_attn_smallseq.h"
 #endif // USE_FUSED_ATTN_CK
 #include "../util/cuda_runtime.h"
 #include "../util/system.h"
@@ -619,19 +619,18 @@ void fused_attn_ck_fwd_impl(
 void* max_seqlen_workspace = workspace;
 
 size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
- static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr,
- max_seqlen_workspace, reinterpret_cast<hipStream_t>(stream)));
+ static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
 size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
- static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr,
- max_seqlen_workspace, reinterpret_cast<hipStream_t>(stream)));
+ static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
 
- if (nvte_log_ck_config) {
- std::cout << std::endl << "[CK small-seq] fused_attn_ck_fwd_impl: is_ragged=1 b=" << b
- << " runtime_max_seqlen_q=" << runtime_max_seqlen_q
- << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv << std::endl;
+ if (std::getenv("NVTE_LOG_CK_CONFIG")) {
+ std::cout << std::endl << "attn_fwd(ck small-seq): ";
+ std::cout << "b: " << b << ", ";
+ std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
+ std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl;
 }
 
- if (runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
+ if (runtime_max_seqlen_q==1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
 fused_attn_rocm::fused_attn_smallseq_fwd(
 b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
 is_training, scaling_factor, dropout_probability,
 devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux,
 devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
 devPtrDropoutSeed, devPtrDropoutOffset,
 dtype, workspace, workspace_size, stream);
 return;
@@ -946,22 +945,24 @@ void fused_attn_ck_bwd_impl(
 void* workspace_next = workspace;
 
 if (is_ragged) {
- void* max_seqlen_workspace_bwd = workspace;
- // When s_q == 1 use 1 for runtime_max_seqlen_q (Q cu_seqlens layout may differ in JAX THD).
- size_t runtime_max_seqlen_q_bwd = (s_q == 1) ? 1u : static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
- static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr,
- max_seqlen_workspace_bwd, reinterpret_cast<hipStream_t>(stream)));
- size_t runtime_max_seqlen_kv_bwd = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
- static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr,
- max_seqlen_workspace_bwd, reinterpret_cast<hipStream_t>(stream)));
- if (nvte_log_ck_config) {
- std::cout << std::endl << "[CK small-seq] fused_attn_ck_bwd_impl: is_ragged=1 runtime_max_seqlen_q="
- << runtime_max_seqlen_q_bwd << " runtime_max_seqlen_kv=" << runtime_max_seqlen_kv_bwd << std::endl;
+ void* max_seqlen_workspace = workspace;
+
+ size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+ b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
+ size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+ b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
+
+ if (std::getenv("NVTE_LOG_CK_CONFIG")) {
+ std::cout << std::endl << "attn_bwd(ck small-seq): ";
+ std::cout << "b: " << b << ", ";
+ std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
+ std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl;
 }
- if (runtime_max_seqlen_q_bwd == 1 && runtime_max_seqlen_kv_bwd >= 2 && runtime_max_seqlen_kv_bwd <= 16) {
+
+ if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
 
 fused_attn_rocm::fused_attn_smallseq_bwd(
- b, h, hg, runtime_max_seqlen_kv_bwd, d_qk, d_v,
+ b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
 scaling_factor, dropout_probability,
 devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux,
 devPtrdQ, devPtrdK, devPtrdV,
 devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
 dtype, workspace, workspace_size, stream);
 return;
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp
index 04ec1dea8..546dce4ed 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp
@@ -18,7 +18,7 @@
 #include "../common.h"
 #include "../util/cuda_runtime.h"
-#include "fused_attn_smallseq.hpp"
+#include "fused_attn_smallseq.h"
 #include "utils.h"
@@ -833,17 +833,20 @@ void fused_attn_smallseq_fwd(size_t b,
 size_t* workspace_size,
 cudaStream_t stream)
 {
- if (std::getenv("NVTE_LOG_CK_CONFIG")) {
- std::cout << "[fused_attn_smallseq_fwd] ENTRY - all params: b=" << b << " h_q=" << h_q
- << " h_kv=" << h_kv << " max_seqlen_kv=" << max_seqlen_kv << " d_qk=" << d_qk
- << " d_v=" << d_v << " is_training=" << is_training << " attn_scale=" << attn_scale
- << " dropout=" << dropout << " qkv_dtype="
+ if (std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ")) {
+ std::cout << std::endl << "attn_fwd(ck small-seq kernel): ";
+ std::cout << "b: " << b << ", ";
+ std::cout << "h_q: " << h_q << ", ";
+ std::cout << "h_kv: " << h_kv << ", ";
+ std::cout << "max_seqlen_kv: " << max_seqlen_kv << ", ";
+ std::cout << "d_qk: " << d_qk << ", ";
+ std::cout << "d_v: " << d_v << ", ";
+ std::cout << "is_training: " << is_training << ", ";
+ std::cout << "attn_scale: " << attn_scale << ", ";
+ std::cout << "dropout: " << dropout << ", ";
+ std::cout << "qkv_dtype: "
 << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?")
- << " devPtrQ=" << devPtrQ << " devPtrK=" << devPtrK << " devPtrV=" << devPtrV
- << " devPtrO=" << devPtrO << " attn_weights_buffer=" << attn_weights_buffer
- << " devPtrCuSeqlensKV=" << devPtrCuSeqlensKV
- << " devPtrSeqOffsetsKV=" << devPtrSeqOffsetsKV << " workspace=" << workspace
- << " stream=" << stream << std::endl;
+ << std::endl;
 }
 (void)h_kv;
 (void)d_qk;
@@ -921,18 +924,19 @@ void fused_attn_smallseq_bwd(size_t b,
 size_t* workspace_size,
 cudaStream_t stream)
 {
- if (std::getenv("NVTE_LOG_CK_CONFIG")) {
- std::cout << "[fused_attn_smallseq_bwd] ENTRY - all params: b=" << b << " h_q=" << h_q
- << " h_kv=" << h_kv << " max_seqlen_kv=" << max_seqlen_kv << " d_qk=" << d_qk
- << " d_v=" << d_v << " attn_scale=" << attn_scale << " dropout=" << dropout
- << " qkv_dtype="
+ if (std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ")) {
+ std::cout << std::endl << "attn_bwd(ck small-seq kernel): ";
+ std::cout << "b: " << b << ", ";
+ std::cout << "h_q: " << h_q << ", ";
+ std::cout << "h_kv: " << h_kv << ", ";
+ std::cout << "max_seqlen_kv: " << max_seqlen_kv << ", ";
+ std::cout << "d_qk: " << d_qk << ", ";
+ std::cout << "d_v: " << d_v << ", ";
+ std::cout << "attn_scale: " << attn_scale << ", ";
+ std::cout << "dropout: " << dropout << ", ";
+ std::cout << "qkv_dtype: "
 << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?")
- << " devPtrQ=" << devPtrQ << " devPtrK=" << devPtrK << " devPtrV=" << devPtrV
- << " devPtrO=" << devPtrO << " devPtrdO=" << devPtrdO << " attn_weights=" << attn_weights
- << " devPtrdQ=" << devPtrdQ << " devPtrdK=" << devPtrdK << " devPtrdV=" << devPtrdV
- << " devPtrCuSeqlensKV=" << devPtrCuSeqlensKV
- << " devPtrSeqOffsetsKV=" << devPtrSeqOffsetsKV << " workspace=" << workspace
- << " stream=" << stream << std::endl;
+ << std::endl;
 }
 (void)h_kv;
 (void)d_qk;
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h
similarity index 99%
rename from transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp
rename to transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h
index f21bfaa0c..ad3d10285 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.hpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h
@@ -4,7 +4,7 @@
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/
 
-/*! \file fused_attn_smallseq.hpp
+/*! \file fused_attn_smallseq.h
 * \brief Unfused small-seq (varlen) attention for ROCm: seq_q=1, max_seqlen_kv<=16, THD only.
*/ diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 8a4a84de5..663d0ceea 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -393,14 +393,11 @@ def abstract( raise ValueError(f"Unsupported {backend=}") if os.environ.get("NVTE_LOG_CK_CONFIG"): - import sys - msg = ( - f"[CK small-seq JAX] fused_attn abstract: backend={backend!s} " - f"batch_shape={batch_shape} q_max_seqlen={q_max_seqlen} " - f"kv_max_seqlen={kv_max_seqlen} attn_heads={attn_heads} " - f"softmax_shape={softmax_shape} softmax_dtype={softmax_dtype} " + print( + "attn_fwd(ck small-seq JAX abstract): " + f"batch_shape: {batch_shape}, softmax_shape: {softmax_shape}, softmax_dtype: {softmax_dtype}" ) - print(msg, file=sys.stderr, flush=True) + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 544df56fa..724010b59 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -223,7 +223,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \ size_t num_segments = input_batch; \ - std::cerr << "[FUSED_ATTN_IMPL_COMMON_BLOCK] input_batch=" << input_batch << std::endl; \ if (is_ragged) { \ auto cudnn_runtime_version = cudnnGetVersion(); \ num_segments = input_batch * max_segments_per_seq; \ @@ -522,15 +521,16 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( workspace_bytes = workspace_elems * elt_size; } - std::cerr << "[GetFusedAttnBackwardWorkspaceSizes] input_batch=" << input_batch - << " is_ragged=" << is_ragged << " workspace_shape=("; - for (size_t i = 0; i < work_shape.size(); ++i) { - std::cerr << (i ? "," : "") << work_shape[i]; + if (std::getenv("NVTE_LOG_CK_CONFIG")) { + std::cout << std::endl << "attn_bwd(ck small-seq workspace size): "; + std::cout << "input_batch: " << input_batch << ", "; + std::cout << "is_ragged: " << is_ragged << ", "; + std::cout << "workspace_elems: " << workspace_elems << ", "; + std::cout << "workspace_bytes: " << workspace_bytes << ", "; + std::cout << "small_seq_min_bytes: " << fused_small_seq_workspace << ", "; + std::cout << "workspace_bytes >= fused_small_seq_workspace: " << (workspace_bytes >= fused_small_seq_workspace ? 
"true" : "false") + << std::endl; } - std::cerr << ") workspace_elems=" << workspace_elems << " workspace_bytes=" << workspace_bytes - << " b*h*16*2=" << fused_small_seq_workspace - << " (workspace_bytes>=b*h*16*2)=" << (workspace_bytes >= fused_small_seq_workspace) - << std::endl; return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } From b6a5ee8ec2d518e9bf27c84e2e548821db01f6ba Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 27 Feb 2026 08:33:48 +0000 Subject: [PATCH 04/21] Guard CK small-seq behind NVTE_FUSED_ATTN_CK_SMALLSEQ=1; add FP16 support to small-seq kernels --- tests/jax/test_fused_attn.py | 54 ++++-- .../include/ck_fused_attn/ck_fused_attn.hpp | 2 +- .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 1 - .../common/fused_attn_rocm/fused_attn_ck.cpp | 12 +- .../fused_attn_rocm/fused_attn_smallseq.cpp | 156 +++++++++++++----- .../fused_attn_rocm/fused_attn_smallseq.h | 4 +- .../jax/cpp_extensions/attention.py | 29 ++-- .../jax/csrc/extensions/attention.cpp | 45 ++--- 8 files changed, 196 insertions(+), 107 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index c598ffdaa..471666699 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -9,6 +9,7 @@ from functools import partial from math import sqrt from typing import Tuple, Optional, Dict +import os import random import jax @@ -329,7 +330,12 @@ class FusedAttnRunner: # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): if self.qkv_layout.is_thd(): - if 90400 <= get_cudnn_version() < 90500 or self.max_seqlen_q == 1: + if ( + 90400 <= get_cudnn_version() < 90500 + or ( self.max_seqlen_q == 1 and + is_hip_extension() and + os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1") + ): return self.num_segments_per_seq else: # +1 for testing runtime_segments < max_segments @@ -539,7 +545,7 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - if self.max_seqlen_q == 1: + if self.max_seqlen_q == 1 and is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": self.num_segments_per_seq = 1 # Q: deterministic — one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size] self.segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) @@ -555,7 +561,6 @@ def generate_random_segment_ids( ) # KV: one segment per batch (num_segments_per_seq=1) to match smallseq kernel - # expectations (batch_size == max_tokens_q, cu_seqlens of size batch_size+1). min_segment_len = None if self.window_size is None else self.seqlens_q self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = ( generate_random_segment_ids( @@ -1247,26 +1252,43 @@ def test_jax_new_rng(): runner.test_forward() -# ROCm CK internal small-seq (varlen unfused) branch tests. +# ROCm CK small-seq varlen tests. # Uses THD_THD_THD with s_q=1, s_kv<=16 so the small-seq path is taken. +# Run only when NVTE_FUSED_ATTN_CK_SMALLSEQ=1. 
+@pytest.mark.skipif( + os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1", + reason="CK unfused smallseq tests require NVTE_FUSED_ATTN_CK_SMALLSEQ=1", +) @pytest.mark.skipif( not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" ) @pytest.mark.parametrize( "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype", [ - pytest.param(30720, 1, 2, 16, 16, 128, 128, jnp.bfloat16, - id="30720-1-2-16-16-128-128-BF16"), - pytest.param(30720, 1, 4, 16, 16, 128, 128, jnp.bfloat16, - id="30720-1-4-16-16-128-128-BF16"), - pytest.param(30720, 1, 6, 16, 16, 128, 128, jnp.bfloat16, - id="30720-1-6-16-16-128-128-BF16"), - pytest.param(30720, 1, 8, 16, 16, 128, 128, jnp.bfloat16, - id="30720-1-8-16-16-128-128-BF16"), - pytest.param(30720, 1, 12, 16, 16, 128, 128, jnp.bfloat16, - id="30720-1-12-16-16-128-128-BF16"), - pytest.param(30720, 1, 16, 16, 16, 128, 128, jnp.bfloat16, - id="30720-1-16-16-16-128-128-BF16"), + pytest.param(4000, 1, 2, 16, 16, 128, 128, jnp.bfloat16, + id="4000-1-2-16-16-128-128-BF16"), + pytest.param(4000, 1, 4, 16, 16, 128, 128, jnp.bfloat16, + id="4000-1-4-16-16-128-128-BF16"), + pytest.param(4000, 1, 6, 16, 16, 128, 128, jnp.bfloat16, + id="4000-1-6-16-16-128-128-BF16"), + pytest.param(4000, 1, 8, 16, 16, 128, 128, jnp.bfloat16, + id="4000-1-8-16-16-128-128-BF16"), + pytest.param(4000, 1, 12, 16, 16, 128, 128, jnp.bfloat16, + id="4000-1-12-16-16-128-128-BF16"), + pytest.param(4000, 1, 16, 16, 16, 128, 128, jnp.bfloat16, + id="4000-1-16-16-16-128-128-BF16"), + pytest.param(4000, 1, 2, 16, 16, 128, 128, jnp.float16, + id="4000-1-2-16-16-128-128-FP16"), + pytest.param(4000, 1, 4, 16, 16, 128, 128, jnp.float16, + id="4000-1-4-16-16-128-128-FP16"), + pytest.param(4000, 1, 6, 16, 16, 128, 128, jnp.float16, + id="4000-1-6-16-16-128-128-FP16"), + pytest.param(4000, 1, 8, 16, 16, 128, 128, jnp.float16, + id="4000-1-8-16-16-128-128-FP16"), + pytest.param(4000, 1, 12, 16, 16, 128, 128, jnp.float16, + id="4000-1-12-16-16-128-128-FP16"), + pytest.param(4000, 1, 16, 16, 16, 128, 128, jnp.float16, + id="4000-1-16-16-16-128-128-FP16"), ], ) def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype): diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 736aa0f99..47528c020 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 6bb9f96e3..26c92ca2b 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -5,7 +5,6 @@ ************************************************************************/ #include - #include "ck_fused_attn_utils.hpp" #include "ck_fused_attn/ck_fused_attn.hpp" #include "mask.hpp" diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index c5de6bda3..6a293c88c 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -615,9 +615,10 @@ void fused_attn_ck_fwd_impl( // denote the next available section of workspace from upstream void* workspace_next = workspace; - if (is_ragged) { + const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); + if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { void* max_seqlen_workspace = workspace; - + size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( static_cast(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( @@ -630,7 +631,7 @@ void fused_attn_ck_fwd_impl( std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl; } - if (runtime_max_seqlen_q==1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { + if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { fused_attn_rocm::fused_attn_smallseq_fwd( b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, is_training, scaling_factor, dropout_probability, @@ -944,7 +945,8 @@ void fused_attn_ck_bwd_impl( // denote the next available section of workspace from upstream void* workspace_next = workspace; - if (is_ragged) { + const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); + if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { void* max_seqlen_workspace = workspace; size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp index 546dce4ed..4bf84320a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp @@ -1,16 +1,16 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ /*! 
\file fused_attn_smallseq.cpp * \brief Unfused small-seq (varlen) attention: seq_q=1, max_seqlen_kv<=16, THD only. - * Ported from varlen_attn/attn_fwd.cpp and attn_bwd.cpp with runtime b, head_num. */ #include #include +#include #include #include @@ -106,30 +106,30 @@ __global__ void compute_scores_kernel(const T* Q, for (int i = 0; i < seq_kv; i++) results[i] = 0.0f; for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { - if constexpr (std::is_same<T, hip_bfloat16>::value) { + if constexpr (std::is_same<T, hip_bfloat16>::value || std::is_same<T, __half>::value) { for (int k = 0; k < block_k / 8; k++) { ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset + k * 8]); - fetch_Q[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; - fetch_Q[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; - fetch_Q[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; - fetch_Q[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; - fetch_Q[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; - fetch_Q[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; - fetch_Q[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; - fetch_Q[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + fetch_Q[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_Q[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_Q[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_Q[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_Q[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_Q[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_Q[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_Q[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; } for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { for (int k = 0; k < block_k / 8; k++) { ls_dwordx4_tmp_var = *((uint4*)&K_ptr[kv_idx * head_num * head_dim + dim_offset + k * 8]); - fetch_K[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; - fetch_K[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; - fetch_K[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; - fetch_K[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; - fetch_K[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; - fetch_K[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; - fetch_K[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; - fetch_K[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + fetch_K[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_K[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_K[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_K[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_K[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_K[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_K[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_K[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; } #pragma unroll for (int k = 0; k < block_k; k++) @@ -502,30 +502,30 @@ __global__ void compute_grad_attn_kernel(const T* grad_O, results[i] = 0.0f; for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { - if constexpr (std::is_same<T, hip_bfloat16>::value) { + if constexpr (std::is_same<T, hip_bfloat16>::value || std::is_same<T, __half>::value) { for (int k = 0; k < block_k / 8; k++) { ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 8]); - fetch_grad_O[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; - fetch_grad_O[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; - fetch_grad_O[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; - fetch_grad_O[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; - fetch_grad_O[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; - 
fetch_grad_O[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; - fetch_grad_O[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; - fetch_grad_O[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + fetch_grad_O[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_grad_O[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_grad_O[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_grad_O[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_grad_O[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_grad_O[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_grad_O[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_grad_O[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; } for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { for (int k = 0; k < block_k / 8; k++) { ls_dwordx4_tmp_var = *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 8]); - fetch_V[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; - fetch_V[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; - fetch_V[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; - fetch_V[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; - fetch_V[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; - fetch_V[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; - fetch_V[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; - fetch_V[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + fetch_V[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_V[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_V[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_V[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_V[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_V[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_V[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_V[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; } #pragma unroll for (int k = 0; k < block_k; k++) @@ -708,7 +708,7 @@ __global__ void compute_grad_qk_kernel(const T* grad_scores, head_dim + thread_head_offset + i * dwordx4_load_elt]; for (int b = 0; b < dwordx4_load_elt; b++) - grad_Q_ptr[b] = ((T*)&store_dwordx4_tmp_var[i])[b] * scale; + grad_Q_ptr[b] = ((T*)&store_dwordx4_tmp_var[i])[b] * T(scale); } #pragma unroll for (int i = 0; i < block_k / dwordx4_load_elt; i++) { @@ -833,8 +833,9 @@ void fused_attn_smallseq_fwd(size_t b, size_t* workspace_size, cudaStream_t stream) { - if (std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ")) { - std::cout << std::endl << "attn_fwd(ck small-seq kernel): "; + const char* nvte_smallseq = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_smallseq && std::string(nvte_smallseq) == "1") { + std::cout << std::endl << "attn_fwd(small-seq kernel): "; std::cout << "b: " << b << ", "; std::cout << "h_q: " << h_q << ", "; std::cout << "h_kv: " << h_kv << ", "; @@ -871,6 +872,38 @@ void fused_attn_smallseq_fwd(size_t b, int bi = static_cast(b); int hi = static_cast(h_q); + switch (max_seqlen_kv) { + SMALLSEQ_DISPATCH_FWD_CASE(2) + SMALLSEQ_DISPATCH_FWD_CASE(3) + SMALLSEQ_DISPATCH_FWD_CASE(4) + SMALLSEQ_DISPATCH_FWD_CASE(5) + SMALLSEQ_DISPATCH_FWD_CASE(6) + SMALLSEQ_DISPATCH_FWD_CASE(7) + SMALLSEQ_DISPATCH_FWD_CASE(8) + SMALLSEQ_DISPATCH_FWD_CASE(9) + SMALLSEQ_DISPATCH_FWD_CASE(10) + SMALLSEQ_DISPATCH_FWD_CASE(11) + SMALLSEQ_DISPATCH_FWD_CASE(12) + SMALLSEQ_DISPATCH_FWD_CASE(13) + SMALLSEQ_DISPATCH_FWD_CASE(14) + SMALLSEQ_DISPATCH_FWD_CASE(15) + SMALLSEQ_DISPATCH_FWD_CASE(16) + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); + } + } else if (qkv_dtype == DType::kFloat16) { + using T = 
__half; + const T* Q_ptr = static_cast<const T*>(devPtrQ); + const T* K_ptr = static_cast<const T*>(devPtrK); + const T* V_ptr = static_cast<const T*>(devPtrV); + T* O_ptr = static_cast<T*>(devPtrO); + T* attn_workspace = static_cast<T*>(attn_weights_buffer); + const int* cu_kv = static_cast<const int*>(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast<const int*>(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast<int>(b); + int hi = static_cast<int>(h_q); + switch (max_seqlen_kv) { SMALLSEQ_DISPATCH_FWD_CASE(2) SMALLSEQ_DISPATCH_FWD_CASE(3) SMALLSEQ_DISPATCH_FWD_CASE(4) SMALLSEQ_DISPATCH_FWD_CASE(5) SMALLSEQ_DISPATCH_FWD_CASE(6) SMALLSEQ_DISPATCH_FWD_CASE(7) SMALLSEQ_DISPATCH_FWD_CASE(8) SMALLSEQ_DISPATCH_FWD_CASE(9) SMALLSEQ_DISPATCH_FWD_CASE(10) SMALLSEQ_DISPATCH_FWD_CASE(11) SMALLSEQ_DISPATCH_FWD_CASE(12) SMALLSEQ_DISPATCH_FWD_CASE(13) SMALLSEQ_DISPATCH_FWD_CASE(14) SMALLSEQ_DISPATCH_FWD_CASE(15) SMALLSEQ_DISPATCH_FWD_CASE(16) default: NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); } } else { - NVTE_ERROR("small-seq path supports only BF16 (and optionally FP16)."); + NVTE_ERROR("small-seq path supports only BF16 and FP16."); } if (workspace_size) { @@ -941,10 +974,6 @@ void fused_attn_smallseq_bwd(size_t b, (void)h_kv; (void)d_qk; (void)d_v; - NVTE_CHECK(max_seqlen_kv >= 2 && max_seqlen_kv <= 16, "small-seq path requires 2 <= max_seqlen_kv <= 16."); - NVTE_CHECK(d_qk == 128 && d_v == 128, "small-seq path currently supports head_dim 128 only."); - NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace."); float sqr_dk_scale = attn_scale; hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream); @@ -967,6 +996,43 @@ void fused_attn_smallseq_bwd(size_t b, int bi = static_cast<int>(b); int hi = static_cast<int>(h_q); + switch (max_seqlen_kv) { + SMALLSEQ_DISPATCH_BWD_CASE(2) + SMALLSEQ_DISPATCH_BWD_CASE(3) + SMALLSEQ_DISPATCH_BWD_CASE(4) + SMALLSEQ_DISPATCH_BWD_CASE(5) + SMALLSEQ_DISPATCH_BWD_CASE(6) + SMALLSEQ_DISPATCH_BWD_CASE(7) + SMALLSEQ_DISPATCH_BWD_CASE(8) + SMALLSEQ_DISPATCH_BWD_CASE(9) + SMALLSEQ_DISPATCH_BWD_CASE(10) + SMALLSEQ_DISPATCH_BWD_CASE(11) + SMALLSEQ_DISPATCH_BWD_CASE(12) + SMALLSEQ_DISPATCH_BWD_CASE(13) + SMALLSEQ_DISPATCH_BWD_CASE(14) + SMALLSEQ_DISPATCH_BWD_CASE(15) + SMALLSEQ_DISPATCH_BWD_CASE(16) + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); + } + } else if (qkv_dtype == DType::kFloat16) { + using T = __half; + const T* Q_ptr = static_cast<const T*>(devPtrQ); + const T* K_ptr = static_cast<const T*>(devPtrK); + const T* V_ptr = static_cast<const T*>(devPtrV); + const T* O_ptr = static_cast<const T*>(devPtrO); + const T* dO_ptr = static_cast<const T*>(devPtrdO); + const T* attn_ptr = static_cast<const T*>(attn_weights); + T* dQ_ptr = static_cast<T*>(devPtrdQ); + T* dK_ptr = static_cast<T*>(devPtrdK); + T* dV_ptr = static_cast<T*>(devPtrdV); + T* workspace_ptr = static_cast<T*>(workspace); + const int* cu_kv = static_cast<const int*>(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast<const int*>(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast<int>(b); + int hi = static_cast<int>(h_q); + switch (max_seqlen_kv) { SMALLSEQ_DISPATCH_BWD_CASE(2) SMALLSEQ_DISPATCH_BWD_CASE(3) @@ -987,7 +1053,7 @@ void fused_attn_smallseq_bwd(size_t b, NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); } } else { - NVTE_ERROR("small-seq path supports only BF16 (and optionally FP16)."); + NVTE_ERROR("small-seq path supports only BF16 and FP16."); } if (workspace_size) *workspace_size = fused_attn_smallseq_bwd_workspace_size(b, h_q, max_seqlen_kv, qkv_dtype); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h index ad3d10285..9a5e8cefc 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h @@ -1,11 +1,11 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ /*! \file fused_attn_smallseq.h - * \brief Unfused small-seq (varlen) attention for ROCm: seq_q=1, max_seqlen_kv<=16, THD only. + * \brief Small-seq (varlen) attention for ROCm: seq_q=1, max_seqlen_kv<=16, THD only. */ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 663d0ceea..90860ff09 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -366,26 +366,21 @@ def abstract( softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) elif backend == NVTE_Fused_Attn_Backend.NVTE_CK: if config.qkv_layout.is_thd(): - batch_size = reduce(operator.mul, batch_shape) - old_ck_softmax_aux_size = ( - batch_size * attn_heads * q_max_seqlen * jnp.dtype(jnp.float32).itemsize - ) - possible_special_cross_attn_softmax_aux_size = ( - batch_size * attn_heads * q_max_seqlen - * min(kv_max_seqlen, 16) * 2 - ) # 2 bytes for bf16/fp16 - if (old_ck_softmax_aux_size - >= possible_special_cross_attn_softmax_aux_size): + # THD only: check env; run small-seq logic only when enabled + if os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1": softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: - softmax_shape = ( - *batch_shape, - attn_heads, - q_max_seqlen, - min(kv_max_seqlen, 16), - ) - softmax_dtype = dtypes.canonicalize_dtype(q_dtype) + batch_size = reduce(operator.mul, batch_shape) + old_ck_softmax_size = (batch_size * attn_heads * q_max_seqlen * 1) + possible_ck_smallseq_softmax_size = (batch_size * attn_heads * + q_max_seqlen * min(kv_max_seqlen, 16) * 2) # 2 bytes for bf16/fp16 + if old_ck_softmax_size >= possible_ck_smallseq_softmax_size: + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) + softmax_dtype = dtypes.canonicalize_dtype(q_dtype) + else: + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16)) + softmax_dtype = dtypes.canonicalize_dtype(q_dtype) else: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 724010b59..2994e1e97 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -509,27 +509,32 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( nvte_tensor_pack_destroy(&aux_input_tensors); auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - size_t workspace_elems = product(work_shape); - size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype()); - size_t workspace_bytes = workspace_elems * elt_size; - size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for small-seq (bf16/fp16) - - if (is_ragged && workspace_bytes < fused_small_seq_workspace) { - size_t min_elems = (fused_small_seq_workspace + elt_size - 1) / elt_size; - work_shape = std::vector{min_elems}; - workspace_elems = min_elems; - workspace_bytes = workspace_elems * elt_size; - } + + const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); + if (nvte_smallseq && 
std::string(nvte_smallseq) == "1") { + size_t workspace_elems = product(work_shape); + size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype()); + size_t workspace_bytes = workspace_elems * elt_size; + size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for small-seq (bf16/fp16) + + if (is_ragged && workspace_bytes < fused_small_seq_workspace) { + size_t min_elems = (fused_small_seq_workspace + elt_size - 1) / elt_size; + work_shape = std::vector{min_elems}; + workspace_elems = min_elems; + workspace_bytes = workspace_elems * elt_size; + } - if (std::getenv("NVTE_LOG_CK_CONFIG")) { - std::cout << std::endl << "attn_bwd(ck small-seq workspace size): "; - std::cout << "input_batch: " << input_batch << ", "; - std::cout << "is_ragged: " << is_ragged << ", "; - std::cout << "workspace_elems: " << workspace_elems << ", "; - std::cout << "workspace_bytes: " << workspace_bytes << ", "; - std::cout << "small_seq_min_bytes: " << fused_small_seq_workspace << ", "; - std::cout << "workspace_bytes >= fused_small_seq_workspace: " << (workspace_bytes >= fused_small_seq_workspace ? "true" : "false") - << std::endl; + const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") { + std::cout << std::endl << "attn_bwd(ck small-seq workspace size): "; + std::cout << "input_batch: " << input_batch << ", "; + std::cout << "is_ragged: " << is_ragged << ", "; + std::cout << "workspace_elems: " << workspace_elems << ", "; + std::cout << "workspace_bytes: " << workspace_bytes << ", "; + std::cout << "small_seq_min_bytes: " << fused_small_seq_workspace << ", "; + std::cout << "workspace_bytes >= fused_small_seq_workspace: " << (workspace_bytes >= fused_small_seq_workspace ? "true" : "false") + << std::endl; + } } return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } From 75f7cfae6e2f2ca2e23695f98b3ea1b80b535630 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 27 Feb 2026 18:16:32 +0000 Subject: [PATCH 05/21] ROCm CK unfused small-seq: env guard, FP16, tests, and logging - tests/jax: CK small-seq tests use fixture to set/restore NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing cases (2048-2-4, 2-4096-8192); when env set, num_segments_per_seq = max_seqlen_q for THD else 2. - JAX attention.py: THD softmax shape/dtype uses small-seq path only when env=1, else original layout - JAX attention.cpp: Added env guard - fused_attn_smallseq: Use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale). 
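For reviewers, a minimal standalone sketch of the guard-plus-dispatch shape this patch converges on (hypothetical helper names small_seq_enabled/launch_small_seq; in the real code TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT performs the dtype switch and the CK small-seq kernels replace the placeholder launch):

    // Sketch under stated assumptions -- not the TE implementation.
    #include <hip/hip_bfloat16.h>  // hip_bfloat16
    #include <hip/hip_fp16.h>      // __half
    #include <cstdlib>
    #include <cstring>

    enum class DType { kBFloat16, kFloat16 };

    // Same convention as the patch: the guard requires the value "1",
    // not merely that NVTE_FUSED_ATTN_CK_SMALLSEQ is set.
    static bool small_seq_enabled() {
      const char* v = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
      return v != nullptr && std::strcmp(v, "1") == 0;
    }

    template <typename T>
    void launch_small_seq(const void* q, void* o) {
      // Stand-in for the per-dtype kernel launch; T is hip_bfloat16 or __half.
      (void)static_cast<const T*>(q);
      (void)static_cast<T*>(o);
    }

    // Returns true when the small-seq path handled the call, so the caller
    // can fall through to the regular CK/AITER path otherwise.
    bool dispatch_small_seq(DType dtype, const void* q, void* o) {
      if (!small_seq_enabled()) return false;
      if (dtype == DType::kBFloat16) launch_small_seq<hip_bfloat16>(q, o);
      else launch_small_seq<__half>(q, o);
      return true;
    }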
--- tests/jax/test_fused_attn.py | 65 +++++------ .../fused_attn_rocm/fused_attn_smallseq.cpp | 102 ++---------------- .../jax/cpp_extensions/attention.py | 20 ++-- .../jax/csrc/extensions/attention.cpp | 18 ++-- 4 files changed, 55 insertions(+), 150 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 471666699..48528b4be 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -332,8 +332,7 @@ def _get_max_segments_per_sequence(self): if self.qkv_layout.is_thd(): if ( 90400 <= get_cudnn_version() < 90500 - or ( self.max_seqlen_q == 1 and - is_hip_extension() and + or ( is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1") ): return self.num_segments_per_seq @@ -575,7 +574,10 @@ def generate_random_segment_ids( self.segment_ids_kv ) else: - self.num_segments_per_seq = 2 + if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": + self.num_segments_per_seq = self.max_seqlen_q + else: + self.num_segments_per_seq = 2 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) @@ -1253,48 +1255,36 @@ def test_jax_new_rng(): # ROCm CK small-seq varlen tests. -# Uses THD_THD_THD with s_q=1, s_kv<=16 so the small-seq path is taken. -# Run only when NVTE_FUSED_ATTN_CK_SMALLSEQ=1. -@pytest.mark.skipif( - os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1", - reason="CK unfused smallseq tests require NVTE_FUSED_ATTN_CK_SMALLSEQ=1", -) @pytest.mark.skipif( not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" ) + +@pytest.fixture +def ck_smallseq_env(monkeypatch): + monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") + yield + +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) @pytest.mark.parametrize( - "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype", + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v", [ - pytest.param(4000, 1, 2, 16, 16, 128, 128, jnp.bfloat16, - id="4000-1-2-16-16-128-128-BF16"), - pytest.param(4000, 1, 4, 16, 16, 128, 128, jnp.bfloat16, - id="4000-1-4-16-16-128-128-BF16"), - pytest.param(4000, 1, 6, 16, 16, 128, 128, jnp.bfloat16, - id="4000-1-6-16-16-128-128-BF16"), - pytest.param(4000, 1, 8, 16, 16, 128, 128, jnp.bfloat16, - id="4000-1-8-16-16-128-128-BF16"), - pytest.param(4000, 1, 12, 16, 16, 128, 128, jnp.bfloat16, - id="4000-1-12-16-16-128-128-BF16"), - pytest.param(4000, 1, 16, 16, 16, 128, 128, jnp.bfloat16, - id="4000-1-16-16-16-128-128-BF16"), - pytest.param(4000, 1, 2, 16, 16, 128, 128, jnp.float16, - id="4000-1-2-16-16-128-128-FP16"), - pytest.param(4000, 1, 4, 16, 16, 128, 128, jnp.float16, - id="4000-1-4-16-16-128-128-FP16"), - pytest.param(4000, 1, 6, 16, 16, 128, 128, jnp.float16, - id="4000-1-6-16-16-128-128-FP16"), - pytest.param(4000, 1, 8, 16, 16, 128, 128, jnp.float16, - id="4000-1-8-16-16-128-128-FP16"), - pytest.param(4000, 1, 12, 16, 16, 128, 128, jnp.float16, - id="4000-1-12-16-16-128-128-FP16"), - pytest.param(4000, 1, 16, 16, 16, 128, 128, jnp.float16, - id="4000-1-16-16-16-128-128-FP16"), + pytest.param(4000, 1, 2, 16, 16, 128, 128, id="4000-1-2-16-16-128-128"), + pytest.param(4000, 1, 4, 16, 16, 128, 128, id="4000-1-4-16-16-128-128"), + pytest.param(4000, 1, 6, 16, 16, 128, 128, id="4000-1-6-16-16-128-128"), + pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"), + pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"), + 
pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"), + pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"), + pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"), ], ) -def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype): +def test_ck_unfused_smallseq_backend( + ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype +): """ Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout. - Uses THD_THD_THD (Q,K,V all THD). + Uses THD_THD_THD (Q,K,V all THD). ck_smallseq_env sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1 and + restores it after the test. """ runner = FusedAttnRunner( batch_size=b, @@ -1316,6 +1306,5 @@ def test_ck_unfused_smallseq_backend(b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype): seq_desc_format=SeqDescFormat.Seqlens, ) runner._setup_inputs() - expected_backend = NVTE_Fused_Attn_Backend.NVTE_CK - runner.test_forward() + # runner.test_forward() runner.test_backward() diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp index 4bf84320a..9d484e83e 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp @@ -849,18 +849,11 @@ void fused_attn_smallseq_fwd(size_t b, << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?") << std::endl; } - (void)h_kv; - (void)d_qk; - (void)d_v; - (void)is_training; - (void)rng_seed; - (void)rng_offset; float sqr_dk_scale = attn_scale; hipStream_t hip_stream = reinterpret_cast(stream); - if (qkv_dtype == DType::kBFloat16) { - using T = hip_bfloat16; + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, const T* Q_ptr = static_cast(devPtrQ); const T* K_ptr = static_cast(devPtrK); const T* V_ptr = static_cast(devPtrV); @@ -891,46 +884,8 @@ void fused_attn_smallseq_fwd(size_t b, default: NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); } - } else if (qkv_dtype == DType::kFloat16) { - using T = __half; - const T* Q_ptr = static_cast(devPtrQ); - const T* K_ptr = static_cast(devPtrK); - const T* V_ptr = static_cast(devPtrV); - T* O_ptr = static_cast(devPtrO); - T* attn_workspace = static_cast(attn_weights_buffer); - const int* cu_kv = static_cast(devPtrCuSeqlensKV); - const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); - const T* dropout_mask = nullptr; - int bi = static_cast(b); - int hi = static_cast(h_q); - - switch (max_seqlen_kv) { - SMALLSEQ_DISPATCH_FWD_CASE(2) - SMALLSEQ_DISPATCH_FWD_CASE(3) - SMALLSEQ_DISPATCH_FWD_CASE(4) - SMALLSEQ_DISPATCH_FWD_CASE(5) - SMALLSEQ_DISPATCH_FWD_CASE(6) - SMALLSEQ_DISPATCH_FWD_CASE(7) - SMALLSEQ_DISPATCH_FWD_CASE(8) - SMALLSEQ_DISPATCH_FWD_CASE(9) - SMALLSEQ_DISPATCH_FWD_CASE(10) - SMALLSEQ_DISPATCH_FWD_CASE(11) - SMALLSEQ_DISPATCH_FWD_CASE(12) - SMALLSEQ_DISPATCH_FWD_CASE(13) - SMALLSEQ_DISPATCH_FWD_CASE(14) - SMALLSEQ_DISPATCH_FWD_CASE(15) - SMALLSEQ_DISPATCH_FWD_CASE(16) - default: - NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); - } - } else { - NVTE_ERROR("small-seq path supports only BF16 and FP16."); - } + ); - if (workspace_size) { - size_t bwd_ws = fused_attn_smallseq_bwd_workspace_size(b, h_q, max_seqlen_kv, qkv_dtype); - *workspace_size = (bwd_ws > 8u) ? 
bwd_ws : 8u; - } } void fused_attn_smallseq_bwd(size_t b, @@ -957,7 +912,8 @@ void fused_attn_smallseq_bwd(size_t b, size_t* workspace_size, cudaStream_t stream) { - if (std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ")) { + const char* nvte_smallseq = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_smallseq && std::string(nvte_smallseq) == "1") { std::cout << std::endl << "attn_bwd(ck small-seq kernel): "; std::cout << "b: " << b << ", "; std::cout << "h_q: " << h_q << ", "; @@ -971,15 +927,11 @@ void fused_attn_smallseq_bwd(size_t b, << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? "FP16" : "?") << std::endl; } - (void)h_kv; - (void)d_qk; - (void)d_v; float sqr_dk_scale = attn_scale; hipStream_t hip_stream = reinterpret_cast(stream); - if (qkv_dtype == DType::kBFloat16) { - using T = hip_bfloat16; + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, const T* Q_ptr = static_cast(devPtrQ); const T* K_ptr = static_cast(devPtrK); const T* V_ptr = static_cast(devPtrV); @@ -1015,49 +967,7 @@ void fused_attn_smallseq_bwd(size_t b, default: NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); } - } else if (qkv_dtype == DType::kFloat16) { - using T = __half; - const T* Q_ptr = static_cast(devPtrQ); - const T* K_ptr = static_cast(devPtrK); - const T* V_ptr = static_cast(devPtrV); - const T* O_ptr = static_cast(devPtrO); - const T* dO_ptr = static_cast(devPtrdO); - const T* attn_ptr = static_cast(attn_weights); - T* dQ_ptr = static_cast(devPtrdQ); - T* dK_ptr = static_cast(devPtrdK); - T* dV_ptr = static_cast(devPtrdV); - T* workspace_ptr = static_cast(workspace); - const int* cu_kv = static_cast(devPtrCuSeqlensKV); - const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); - const T* dropout_mask = nullptr; - int bi = static_cast(b); - int hi = static_cast(h_q); - - switch (max_seqlen_kv) { - SMALLSEQ_DISPATCH_BWD_CASE(2) - SMALLSEQ_DISPATCH_BWD_CASE(3) - SMALLSEQ_DISPATCH_BWD_CASE(4) - SMALLSEQ_DISPATCH_BWD_CASE(5) - SMALLSEQ_DISPATCH_BWD_CASE(6) - SMALLSEQ_DISPATCH_BWD_CASE(7) - SMALLSEQ_DISPATCH_BWD_CASE(8) - SMALLSEQ_DISPATCH_BWD_CASE(9) - SMALLSEQ_DISPATCH_BWD_CASE(10) - SMALLSEQ_DISPATCH_BWD_CASE(11) - SMALLSEQ_DISPATCH_BWD_CASE(12) - SMALLSEQ_DISPATCH_BWD_CASE(13) - SMALLSEQ_DISPATCH_BWD_CASE(14) - SMALLSEQ_DISPATCH_BWD_CASE(15) - SMALLSEQ_DISPATCH_BWD_CASE(16) - default: - NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); - } - } else { - NVTE_ERROR("small-seq path supports only BF16 and FP16."); - } - - if (workspace_size) - *workspace_size = fused_attn_smallseq_bwd_workspace_size(b, h_q, max_seqlen_kv, qkv_dtype); + ); } } // namespace fused_attn_rocm diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 90860ff09..a8839f404 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -372,10 +372,14 @@ def abstract( softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: batch_size = reduce(operator.mul, batch_shape) - old_ck_softmax_size = (batch_size * attn_heads * q_max_seqlen * 1) - possible_ck_smallseq_softmax_size = (batch_size * attn_heads * - q_max_seqlen * min(kv_max_seqlen, 16) * 2) # 2 bytes for bf16/fp16 - if old_ck_softmax_size >= possible_ck_smallseq_softmax_size: + ck_standard_softmax_aux_size = ( + batch_size * attn_heads * q_max_seqlen * 1 + ) + ck_smallseq_softmax_aux_size = ( + batch_size * attn_heads * q_max_seqlen + * min(kv_max_seqlen, 16) * 2 + ) # 2 bytes for 
bf16/fp16 + if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(q_dtype) else: @@ -388,9 +392,11 @@ def abstract( raise ValueError(f"Unsupported {backend=}") if os.environ.get("NVTE_LOG_CK_CONFIG"): - print( - "attn_fwd(ck small-seq JAX abstract): " - f"batch_shape: {batch_shape}, softmax_shape: {softmax_shape}, softmax_dtype: {softmax_dtype}" + jax.debug.print( + "attn_fwd(ck small-seq JAX abstract): batch_shape: {}, softmax_shape: {}, softmax_dtype: {}", + batch_shape, + softmax_shape, + softmax_dtype, ) softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 2994e1e97..55f5575ed 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -509,16 +509,16 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( nvte_tensor_pack_destroy(&aux_input_tensors); auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - + const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); - if (nvte_smallseq && std::string(nvte_smallseq) == "1") { + if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { size_t workspace_elems = product(work_shape); size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype()); size_t workspace_bytes = workspace_elems * elt_size; - size_t fused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for small-seq (bf16/fp16) + size_t unfused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for unfused small-seq (bf16/fp16) - if (is_ragged && workspace_bytes < fused_small_seq_workspace) { - size_t min_elems = (fused_small_seq_workspace + elt_size - 1) / elt_size; + if (workspace_bytes < unfused_small_seq_workspace) { + size_t min_elems = (unfused_small_seq_workspace + elt_size - 1) / elt_size; work_shape = std::vector{min_elems}; workspace_elems = min_elems; workspace_bytes = workspace_elems * elt_size; @@ -526,14 +526,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG"); if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") { - std::cout << std::endl << "attn_bwd(ck small-seq workspace size): "; + std::cout << std::endl << "attn_bwd(ck unfused small-seq workspace size): "; std::cout << "input_batch: " << input_batch << ", "; std::cout << "is_ragged: " << is_ragged << ", "; std::cout << "workspace_elems: " << workspace_elems << ", "; std::cout << "workspace_bytes: " << workspace_bytes << ", "; - std::cout << "small_seq_min_bytes: " << fused_small_seq_workspace << ", "; - std::cout << "workspace_bytes >= fused_small_seq_workspace: " << (workspace_bytes >= fused_small_seq_workspace ? "true" : "false") - << std::endl; + std::cout << "unfused_small_seq_min_bytes: " << unfused_small_seq_workspace << ", "; + std::cout << "workspace_bytes >= unfused_small_seq_workspace: " + << (workspace_bytes >= unfused_small_seq_workspace ? 
"true" : "false") << std::endl; } } return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); From c737072f8297118b7fd9065cca6228609ca101f1 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 27 Feb 2026 19:29:30 +0000 Subject: [PATCH 06/21] Disabled xla_gpu_graph_level --- tests/jax/test_fused_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 48528b4be..7df45596e 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1261,7 +1261,9 @@ def test_jax_new_rng(): @pytest.fixture def ck_smallseq_env(monkeypatch): + """Enable CK small-seq path and disable XLA GPU graphs for these tests.""" monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") + monkeypatch.setenv("XLA_FLAGS", "--xla_gpu_graph_level=0") yield @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) From 4537cce2fa6a3370f3d489a97d39419bd03362fd Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 27 Feb 2026 23:30:44 +0000 Subject: [PATCH 07/21] Updated XLA_FLAGS in ci/jax.sh --- ci/jax.sh | 1 + tests/jax/test_fused_attn.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ci/jax.sh b/ci/jax.sh index 81d994585..d1b1bb890 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -58,6 +58,7 @@ run_test_config() { run_default_fa 1 test_custom_call_compute.py run_default_fa 1 test_functions.py run 1 test_fused_attn.py + XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass run_default_fa 1 test_helper.py run_default_fa 1 test_layer.py #it effectevly always uses unfused attention diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 7df45596e..b69902057 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1262,8 +1262,9 @@ def test_jax_new_rng(): @pytest.fixture def ck_smallseq_env(monkeypatch): """Enable CK small-seq path and disable XLA GPU graphs for these tests.""" + if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""): + pytest.skip("Run with XLA_FLAGS='--xla_gpu_graph_level=0' pytest ...") monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") - monkeypatch.setenv("XLA_FLAGS", "--xla_gpu_graph_level=0") yield @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) From c6e0eaea424805c56bdc66eb3e8f71d1c1dd14d3 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 3 Mar 2026 17:20:57 +0000 Subject: [PATCH 08/21] Adressed comments --- ci/jax.sh | 4 +-- tests/jax/test_fused_attn.py | 9 +++-- .../common/fused_attn_rocm/fused_attn_ck.cpp | 33 ++++++++++++------- .../fused_attn_rocm/fused_attn_smallseq.cpp | 6 ++-- .../fused_attn_rocm/fused_attn_smallseq.h | 3 +- .../jax/cpp_extensions/attention.py | 2 +- 6 files changed, 31 insertions(+), 26 deletions(-) diff --git a/ci/jax.sh b/ci/jax.sh index d1b1bb890..f048492ba 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -57,9 +57,9 @@ run_test_config() { export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests run_default_fa 1 test_custom_call_compute.py run_default_fa 1 test_functions.py - run 1 test_fused_attn.py + run 1 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # skip smallseq in normal flow XLA_FLAGS='--xla_gpu_graph_level=0' run 1 
test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled - NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass + NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # Using FAv2 for forward and backward pass run_default_fa 1 test_helper.py run_default_fa 1 test_layer.py #it effectevly always uses unfused attention run_default_fa 1 test_sanity_import.py diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index b69902057..0bc1d25a6 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1255,15 +1255,11 @@ def test_jax_new_rng(): # ROCm CK small-seq varlen tests. -@pytest.mark.skipif( - not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" -) - @pytest.fixture def ck_smallseq_env(monkeypatch): """Enable CK small-seq path and disable XLA GPU graphs for these tests.""" if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""): - pytest.skip("Run with XLA_FLAGS='--xla_gpu_graph_level=0' pytest ...") + pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'") monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") yield @@ -1281,6 +1277,9 @@ def ck_smallseq_env(monkeypatch): pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"), ], ) +@pytest.mark.skipif( + not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" +) def test_ck_unfused_smallseq_backend( ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype ): diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 6a293c88c..2af841581 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -617,18 +617,24 @@ void fused_attn_ck_fwd_impl( const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { - void* max_seqlen_workspace = workspace; - + void* max_seqlen_workspace = workspace_next; size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( static_cast(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( static_cast(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); + workspace_next = static_cast(static_cast(workspace_next) + sizeof(uint64_t)); - if (std::getenv("NVTE_LOG_CK_CONFIG")) { + if (nvte_log_ck_config) { std::cout << std::endl << "attn_fwd(ck small-seq): "; std::cout << "b: " << b << ", "; std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; - std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl; + std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; + std::cout << "flow: " + << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && + runtime_max_seqlen_kv <= 16 + ? 
"ck-smallseq" + : "regular ck/aiter") + << std::endl; } if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { @@ -947,22 +953,27 @@ void fused_attn_ck_bwd_impl( const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { - void* max_seqlen_workspace = workspace; - + void* max_seqlen_workspace = workspace_next; size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); - - if (std::getenv("NVTE_LOG_CK_CONFIG")) { + workspace_next = static_cast(static_cast(workspace_next) + sizeof(uint64_t)); + + if (nvte_log_ck_config) { std::cout << std::endl << "attn_bwd(ck small-seq): "; std::cout << "b: " << b << ", "; std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; - std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << std::endl; + std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; + std::cout << "flow: " + << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && + runtime_max_seqlen_kv <= 16 + ? "ck-smallseq" + : "regular ck/aiter") + << std::endl; } if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { - fused_attn_rocm::fused_attn_smallseq_bwd( b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, scaling_factor, dropout_probability, @@ -1887,7 +1898,6 @@ void fused_attn_ck_fwd( size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast(1), std::multiplies())/h_kv/d_qk; bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; - if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; @@ -1942,7 +1952,6 @@ void fused_attn_ck_fwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, max_tokens_q, max_tokens_kv, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp index 9d484e83e..789beffa2 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp @@ -27,13 +27,13 @@ case N: \ dispatch_fwd(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, \ sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, \ - hip_stream); \ + stream); \ break; #define SMALLSEQ_DISPATCH_BWD_CASE(N) \ case N: \ dispatch_bwd(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, \ dropout_mask, dropout, sqr_dk_scale, dQ_ptr, dK_ptr, \ - dV_ptr, workspace_ptr, cu_kv, cu_kv_p, hip_stream); \ + dV_ptr, workspace_ptr, cu_kv, cu_kv_p, stream); \ break; namespace transformer_engine { @@ -851,7 +851,6 @@ void fused_attn_smallseq_fwd(size_t b, } float sqr_dk_scale = attn_scale; - hipStream_t hip_stream = reinterpret_cast(stream); TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, const T* Q_ptr = static_cast(devPtrQ); @@ -929,7 +928,6 @@ void fused_attn_smallseq_bwd(size_t b, } float sqr_dk_scale = attn_scale; - hipStream_t hip_stream = 
reinterpret_cast(stream); TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, const T* Q_ptr = static_cast(devPtrQ); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h index 9a5e8cefc..818b5448a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h @@ -11,8 +11,7 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ -#include "../common.h" -#include "transformer_engine/fused_attn.h" +#include namespace transformer_engine { namespace fused_attn_rocm { diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index a8839f404..6b9b0a30a 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -391,7 +391,7 @@ def abstract( else: raise ValueError(f"Unsupported {backend=}") - if os.environ.get("NVTE_LOG_CK_CONFIG"): + if os.environ.get("NVTE_LOG_CK_CONFIG", "0") == "1": jax.debug.print( "attn_fwd(ck small-seq JAX abstract): batch_shape: {}, softmax_shape: {}, softmax_dtype: {}", batch_shape, From 366945e3c1b0217598f806f0e9ff6673ea53f2ea Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 3 Mar 2026 22:01:49 +0000 Subject: [PATCH 09/21] Refactored input generation for smallseq flow --- tests/jax/test_fused_attn.py | 99 ++++++++++++++++++++++++------------ 1 file changed, 66 insertions(+), 33 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 0bc1d25a6..8e2684a1b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -423,6 +423,57 @@ def _check_configs(self): "the F16_arbitrary_seqlen backend." ) + def _setup_thd_segments_ck_smallseq(self, generate_random_segment_ids): + """ + Build THD segment descriptors for the CK small-seq path (NVTE_FUSED_ATTN_CK_SMALLSEQ=1). + + Uses num_segments_per_seq = max_seqlen_q for both Q and KV. For Q: if max_seqlen_q == 1, + uses a fixed layout (one token per batch, cu_seqlens [0,1,...,batch_size]); otherwise + generates random segments. For KV: always generates random segments. 
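+
+        Worked example of the fixed s_q == 1 layout (hypothetical batch_size=3;
+        the values follow directly from the code below):
+            seqlens_q -> [[1], [1], [1]]
+            offsets_q -> [[0, -1], [1, -1], [2, -1]]   (cu_seqlens [0, 1, 2, 3])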
+ """ + num_segments_per_seq = self.max_seqlen_q + if self.max_seqlen_q == 1: + # Q: deterministic - one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size] + segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32) + offsets_q = jnp.concatenate( + [ + jnp.arange(self.batch_size, dtype=jnp.int32)[:, None], + jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), + ], + axis=1, + ) + else: + segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42 + ) + seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q) + + min_segment_len = None if self.window_size is None else seqlens_q + segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + seqlens_kv, offsets_kv = get_seqlens_and_offsets(segment_ids_kv) + return ( + num_segments_per_seq, + segment_ids_q, + segment_pos_q, + pad_q, + seqlens_q, + offsets_q, + segment_ids_kv, + segment_pos_kv, + pad_kv, + seqlens_kv, + offsets_kv, + ) + def _setup_inputs(self): self._check_configs() @@ -544,40 +595,22 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - if self.max_seqlen_q == 1 and is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": - self.num_segments_per_seq = 1 - # Q: deterministic — one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size] - self.segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) - self.segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) - self.pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) - self.seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32) - self.offsets_q = jnp.concatenate( - [ - jnp.arange(self.batch_size, dtype=jnp.int32)[:, None], - jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), - ], - axis=1, - ) - - # KV: one segment per batch (num_segments_per_seq=1) to match smallseq kernel - min_segment_len = None if self.window_size is None else self.seqlens_q - self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = ( - generate_random_segment_ids( - self.batch_size, - self.max_seqlen_kv, - self.num_segments_per_seq, # 1 for s_q=1 path - seed=2024, - min_segment_len=min_segment_len, - ) - ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets( - self.segment_ids_kv - ) + if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": + ( + self.num_segments_per_seq, + self.segment_ids_q, + self.segment_pos_q, + self.pad_q, + self.seqlens_q, + self.offsets_q, + self.segment_ids_kv, + self.segment_pos_kv, + self.pad_kv, + self.seqlens_kv, + self.offsets_kv, + ) = self._setup_thd_segments_ck_smallseq(generate_random_segment_ids) else: - if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": - self.num_segments_per_seq = self.max_seqlen_q - else: - self.num_segments_per_seq = 2 + self.num_segments_per_seq = 2 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) From d5afb6fac89b6d42458d867cbffad98ec9a7d979 Mon 
Sep 17 00:00:00 2001 From: Ye Wang Date: Fri, 13 Mar 2026 16:49:12 -0500 Subject: [PATCH 10/21] [ROCm] apply more strict filtering for just cross-attn and fix the softmax shape bug --- .../common/fused_attn_rocm/fused_attn_ck.cpp | 4 ++-- transformer_engine/jax/cpp_extensions/attention.py | 8 ++++---- transformer_engine/jax/csrc/extensions/attention.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 2af841581..da89ed9fd 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -616,7 +616,7 @@ void fused_attn_ck_fwd_impl( void* workspace_next = workspace; const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); - if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { + if (is_ragged && s_q!=s_kv && nvte_smallseq && std::string(nvte_smallseq) == "1") { void* max_seqlen_workspace = workspace_next; size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( static_cast(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); @@ -952,7 +952,7 @@ void fused_attn_ck_bwd_impl( void* workspace_next = workspace; const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); - if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { + if (is_ragged && s_q!=s_kv && nvte_smallseq && std::string(nvte_smallseq) == "1") { void* max_seqlen_workspace = workspace_next; size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6b9b0a30a..cd64e259d 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -367,21 +367,21 @@ def abstract( elif backend == NVTE_Fused_Attn_Backend.NVTE_CK: if config.qkv_layout.is_thd(): # THD only: check env; run small-seq logic only when enabled - if os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1": + if os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1" or q_max_seqlen == kv_max_seqlen: softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: batch_size = reduce(operator.mul, batch_shape) ck_standard_softmax_aux_size = ( - batch_size * attn_heads * q_max_seqlen * 1 - ) + batch_size * attn_heads * q_max_seqlen * 1 * 4 + ) # 4 bytes for the float32 ck_smallseq_softmax_aux_size = ( batch_size * attn_heads * q_max_seqlen * min(kv_max_seqlen, 16) * 2 ) # 2 bytes for bf16/fp16 if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(q_dtype) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16)) softmax_dtype = dtypes.canonicalize_dtype(q_dtype) diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 55f5575ed..b575a56a4 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -511,7 +511,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto work_shape = 
MakeShapeVector(query_workspace_tensor.shape()); const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); - if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") { + if (is_ragged && q_max_seqlen!=kv_max_seqlen && nvte_smallseq && std::string(nvte_smallseq) == "1") { size_t workspace_elems = product(work_shape); size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype()); size_t workspace_bytes = workspace_elems * elt_size; From 006edee929dca7cd33c771204ad316d3c0874820 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Wed, 15 Apr 2026 16:59:02 +0000 Subject: [PATCH 11/21] Refactor small-seq kernels and add NVTE hooks for explicit fused-attn backend Refactor the ROCm small-sequence attention path so it is a first-class backend instead of branching from the generic CK fused-attention entry: add NVTE entry points and ROCm implementations under fused_attn_rocm, CMake wiring, and public declarations in fused_attn.h. Rename the small-seq sources to fused_attn_small_seq.* so filenames match the new API. Extend kernel dispatch to head sizes 128, 256, and 512. --- transformer_engine/common/CMakeLists.txt | 2 +- .../common/fused_attn_rocm/fused_attn.cpp | 177 ++++++++++++++++ .../common/fused_attn_rocm/fused_attn_ck.cpp | 69 ------- ..._smallseq.cpp => fused_attn_small_seq.cpp} | 191 ++++++++++++++++-- ...attn_smallseq.h => fused_attn_small_seq.h} | 33 ++- .../include/transformer_engine/fused_attn.h | 37 ++++ 6 files changed, 412 insertions(+), 97 deletions(-) rename transformer_engine/common/fused_attn_rocm/{fused_attn_smallseq.cpp => fused_attn_small_seq.cpp} (87%) rename transformer_engine/common/fused_attn_rocm/{fused_attn_smallseq.h => fused_attn_small_seq.h} (78%) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 6774acfd2..61be29cbb 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -200,7 +200,7 @@ else() fused_attn_rocm/fused_attn.cpp fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/fused_attn_smallseq.cpp + fused_attn_rocm/fused_attn_small_seq.cpp fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu amd_detail/system.cpp) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index d8969b2ab..dd48c03a7 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -10,6 +10,7 @@ #include "transformer_engine/fused_attn.h" #include "fused_attn_aotriton.h" #include "fused_attn_ck.h" +#include "fused_attn_small_seq.h" #include "../common.h" #include "utils.h" @@ -875,6 +876,182 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } } +bool nvte_is_small_seq_attn_supported( + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { + return transformer_engine::fused_attn_rocm::is_small_seq_attn_supported( + q_dtype, kv_dtype, qkv_layout, bias_type, attn_mask_type, dropout, num_attn_heads, + num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, + window_size_right); +} + +size_t 
nvte_fused_attn_small_seq_bwd_workspace_size(size_t batch, size_t attn_heads, + size_t max_seqlen_kv, NVTEDType dtype) { + return transformer_engine::fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size( + batch, attn_heads, max_seqlen_kv, static_cast(dtype)); +} + +void nvte_fused_attn_small_seq_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_attn_small_seq_fwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + convertNVTETensorCheck(Bias); + convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensorCheck(workspace); + + auto ndim = input_Q->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; + + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + window_size_left, window_size_right); + + std::tie(window_size_left, window_size_right) = + check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); + + NVTE_CHECK( + fused_attn_rocm::is_small_seq_attn_supported( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right), + "nvte_fused_attn_small_seq_fwd: configuration not supported for small-seq path."); + + Tensor *softmax_aux_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + void *attn_weights_buf = softmax_aux_tensor->data.dptr; + + if (wkspace->data.dptr == nullptr) { + wkspace->data.shape = {1}; + wkspace->data.dtype = DType::kByte; + return; + } + + void *dev_ptr_seed = input_rng_state->data.dptr; + void *dev_ptr_offset = + reinterpret_cast(reinterpret_cast(input_rng_state->data.dptr) + 1); + + size_t workspace_bytes = 1; + for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { + workspace_bytes *= wkspace->data.shape[i]; + } + workspace_bytes *= nvte_dtype_size(wkspace->data.dtype); + + fused_attn_rocm::fused_attn_small_seq_fwd( + b, h_q, h_kv, max_seqlen_kv, d_qk, d_v, 
is_training, attn_scale, dropout, + input_Q->data.dptr, input_K->data.dptr, input_V->data.dptr, output_O->data.dptr, + attn_weights_buf, input_cu_seqlens_kv->data.dptr, input_cu_seqlens_kv_padded->data.dptr, + dev_ptr_seed, dev_ptr_offset, input_Q->data.dtype, wkspace->data.dptr, &workspace_bytes, + stream); + (void)page_table_k; + (void)page_table_v; +} + +void nvte_fused_attn_small_seq_bwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, + const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, + NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_attn_small_seq_bwd); + (void)deterministic; + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + convertNVTETensorCheck(S); + convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dK = convertNVTETensorCheck(dK); + Tensor *output_dV = convertNVTETensorCheck(dV); + convertNVTETensorCheck(dBias); + Tensor *wkspace = convertNVTETensorCheck(workspace); + + const Tensor *attn_weights_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + + auto ndim = input_Q->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; + + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + window_size_left, window_size_right); + + std::tie(window_size_left, window_size_right) = + check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); + + NVTE_CHECK( + fused_attn_rocm::is_small_seq_attn_supported( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right), + "nvte_fused_attn_small_seq_bwd: configuration not supported for small-seq path."); + + size_t req_bytes = fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size( + b, h_q, max_seqlen_kv, input_Q->data.dtype); + + if (wkspace->data.dptr == nullptr) { + wkspace->data.shape = {req_bytes}; + wkspace->data.dtype = DType::kByte; + return; + } + + size_t 
workspace_bytes = 1; + for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { + workspace_bytes *= wkspace->data.shape[i]; + } + workspace_bytes *= nvte_dtype_size(wkspace->data.dtype); + NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small."); + + fused_attn_rocm::fused_attn_small_seq_bwd( + b, h_q, h_kv, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, input_Q->data.dptr, + input_K->data.dptr, input_V->data.dptr, input_O->data.dptr, input_dO->data.dptr, + attn_weights_tensor->data.dptr, output_dQ->data.dptr, output_dK->data.dptr, + output_dV->data.dptr, input_cu_seqlens_kv->data.dptr, + input_cu_seqlens_kv_padded->data.dptr, input_Q->data.dtype, wkspace->data.dptr, + &workspace_bytes, stream); +} + uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t max_batch_size, cudaStream_t stream) { NVTE_API_CALL(nvte_get_runtime_num_segments); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index da89ed9fd..f19cae104 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -9,7 +9,6 @@ #include // Required for std::accumulate #ifdef USE_FUSED_ATTN_CK #include -#include "fused_attn_smallseq.h" #endif // USE_FUSED_ATTN_CK #include "../util/cuda_runtime.h" #include "../util/system.h" @@ -615,40 +614,6 @@ void fused_attn_ck_fwd_impl( // denote the next available section of workspace from upstream void* workspace_next = workspace; - const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); - if (is_ragged && s_q!=s_kv && nvte_smallseq && std::string(nvte_smallseq) == "1") { - void* max_seqlen_workspace = workspace_next; - size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( - static_cast(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); - size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( - static_cast(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); - workspace_next = static_cast(static_cast(workspace_next) + sizeof(uint64_t)); - - if (nvte_log_ck_config) { - std::cout << std::endl << "attn_fwd(ck small-seq): "; - std::cout << "b: " << b << ", "; - std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; - std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; - std::cout << "flow: " - << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && - runtime_max_seqlen_kv <= 16 - ? 
"ck-smallseq" - : "regular ck/aiter") - << std::endl; - } - - if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { - fused_attn_rocm::fused_attn_smallseq_fwd( - b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, - is_training, scaling_factor, dropout_probability, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux, - devPtrCuSeqlensKV, devPtrSeqOffsetsKV, - devPtrDropoutSeed, devPtrDropoutOffset, - dtype, workspace, workspace_size, stream); - return; - } - } - std::array q_stride; std::array k_stride; std::array v_stride; @@ -951,40 +916,6 @@ void fused_attn_ck_bwd_impl( // denote the next available section of workspace from upstream void* workspace_next = workspace; - const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); - if (is_ragged && s_q!=s_kv && nvte_smallseq && std::string(nvte_smallseq) == "1") { - void* max_seqlen_workspace = workspace_next; - size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( - b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream)); - size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( - b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream)); - workspace_next = static_cast(static_cast(workspace_next) + sizeof(uint64_t)); - - if (nvte_log_ck_config) { - std::cout << std::endl << "attn_bwd(ck small-seq): "; - std::cout << "b: " << b << ", "; - std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; - std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; - std::cout << "flow: " - << (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && - runtime_max_seqlen_kv <= 16 - ? "ck-smallseq" - : "regular ck/aiter") - << std::endl; - } - - if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) { - fused_attn_rocm::fused_attn_smallseq_bwd( - b, h, hg, runtime_max_seqlen_kv, d_qk, d_v, - scaling_factor, dropout_probability, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux, - devPtrdQ, devPtrdK, devPtrdV, - devPtrCuSeqlensKV, devPtrSeqOffsetsKV, - dtype, workspace, workspace_size, stream); - return; - } - } - std::array q_stride; std::array k_stride; std::array v_stride; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp similarity index 87% rename from transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp rename to transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp index 789beffa2..223ce8981 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp @@ -4,8 +4,8 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -/*! \file fused_attn_smallseq.cpp - * \brief Unfused small-seq (varlen) attention: seq_q=1, max_seqlen_kv<=16, THD only. +/*! \file fused_attn_small_seq.cpp + * \brief small-seq (varlen) attention: seq_q=1, max_seqlen_kv<=16, THD only. */ #include @@ -15,25 +15,26 @@ #include #include #include +#include #include "../common.h" #include "../util/cuda_runtime.h" -#include "fused_attn_smallseq.h" +#include "fused_attn_small_seq.h" #include "utils.h" // Macros to avoid repeating dispatch switch cases for max_seqlen_kv in [2, 16]. -// T, bi, hi and the pointer/scale args must be in scope where these are used. 
+// T, bi, hi, d_qk and the pointer/scale args must be in scope where these are used. #define SMALLSEQ_DISPATCH_FWD_CASE(N) \ case N: \ dispatch_fwd(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, \ sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, \ - stream); \ + d_qk, stream); \ break; #define SMALLSEQ_DISPATCH_BWD_CASE(N) \ case N: \ dispatch_bwd(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, \ dropout_mask, dropout, sqr_dk_scale, dQ_ptr, dK_ptr, \ - dV_ptr, workspace_ptr, cu_kv, cu_kv_p, stream); \ + dV_ptr, workspace_ptr, cu_kv, cu_kv_p, d_qk, stream); \ break; namespace transformer_engine { @@ -57,9 +58,8 @@ struct SmallSeqConfig { /* MAX_SEQ_KV and HEAD_DIM are compile-time so kernels can use fixed stack arrays * (e.g. float results[max_seq_kv], T attn[max_seq_kv]) and constexpr grid/block - * sizes. This matches varlen_attn/attn_fwd.cpp (FmhaKernelConfig<..., MAX_SEQ_KV, HEAD_DIM>) - * and INTEGRATION_TASK.md: seq_q==1, max_seq_kv<=16; head_dim=128 is the only - * value tested in varlen_attn (main() uses TestRunner<2,16>::run<..., 128, ...>). */ + * sizes. This matches varlen_attn/attn_fwd.cpp (FmhaKernelConfig<..., MAX_SEQ_KV, HEAD_DIM>). + * Dispatch supports head_dim 128, 256, 512 (d_qk == d_v). */ // ----- Forward kernels (with runtime batch_size, head_num) ----- @@ -784,7 +784,7 @@ void run_attn_bwd_impl(int b, workspace, Q, K, grad_Q, grad_K, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); } -size_t fused_attn_smallseq_bwd_workspace_size(size_t b, +size_t fused_attn_small_seq_bwd_workspace_size(size_t b, size_t h_q, size_t max_seqlen_kv, DType dtype) { @@ -795,22 +795,56 @@ size_t fused_attn_smallseq_bwd_workspace_size(size_t b, template static void dispatch_fwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* dropout_mask, float dropout, float scale, T* O, T* workspace, const int* cu_kv, - const int* cu_kv_p, hipStream_t stream) { - run_attn_fwd_impl>( - b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); + const int* cu_kv_p, size_t d_qk, hipStream_t stream) { + switch (d_qk) { + case 128: + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); + break; + case 256: + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); + break; + case 512: + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); + break; + default: + NVTE_ERROR( + "Unsupported head dimension (d_qk) for small-seq attention: must be 128, 256, or " + "512."); + } } template static void dispatch_bwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* grad_O, const T* attn_weights, const T* dropout_mask, float dropout, float scale, T* grad_Q, T* grad_K, T* grad_V, T* workspace, const int* cu_kv, - const int* cu_kv_p, hipStream_t stream) { - run_attn_bwd_impl>( - b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, - grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); + const int* cu_kv_p, size_t d_qk, hipStream_t stream) { + switch (d_qk) { + case 128: + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); + break; + case 256: + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); + break; + case 512: + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, 
attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); + break; + default: + NVTE_ERROR( + "Unsupported head dimension (d_qk) for small-seq attention: must be 128, 256, or " + "512."); + } } -void fused_attn_smallseq_fwd(size_t b, +void fused_attn_small_seq_fwd(size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_kv, @@ -887,7 +921,7 @@ void fused_attn_smallseq_fwd(size_t b, } -void fused_attn_smallseq_bwd(size_t b, +void fused_attn_small_seq_bwd(size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_kv, @@ -968,5 +1002,124 @@ void fused_attn_smallseq_bwd(size_t b, ); } +bool is_small_seq_attn_supported( + NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float dropout, + size_t num_attn_heads, + size_t num_gqa_groups, + size_t max_seqlen_q, + size_t max_seqlen_kv, + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, + int64_t window_size_right) { + bool log = false; + if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG")) { + log = (env_p != nullptr && std::string(env_p) == "1"); + } + + if (num_gqa_groups == 0 || num_attn_heads % num_gqa_groups != 0) { + if (log) { + std::cout << "small-seq: heads must be divisible by GQA groups" << std::endl; + } + return false; + } + + if (q_dtype != kv_dtype || + !(q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16)) { + if (log) { + std::cout << "small-seq: Q/K/V must be FP16 or BF16 and match" << std::endl; + } + return false; + } + + if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { + if (log) { + std::cout << "small-seq: bias not supported" << std::endl; + } + return false; + } + + if (dropout != 0.f) { + if (log) { + std::cout << "small-seq: dropout not supported in kernel path" << std::endl; + } + return false; + } + + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group != NVTE_QKV_Layout_Group::NVTE_HD_HD_HD || + nvte_get_qkv_format(qkv_layout) != NVTE_QKV_Format::NVTE_THD) { + if (log) { + std::cout << "small-seq: requires THD separate Q, K, V (THD_THD_THD)" << std::endl; + } + return false; + } + + if (qkv_layout != NVTE_QKV_Layout::NVTE_THD_THD_THD) { + if (log) { + std::cout << "small-seq: layout must be NVTE_THD_THD_THD" << std::endl; + } + return false; + } + + // max_seqlen_q / max_seqlen_kv are compile-time or padded upper bounds; actual varlen + // lengths are in cu_seqlens / offsets. Do not reject here based on those maxima — callers + // must ensure runtime lengths and the launch-time max_seqlen_kv passed to fwd/bwd match + // the small-seq kernel contract (e.g. dispatch cases 2..16, effective q len 1). 
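+  // For reference, the runtime gate that the removed fused_attn_ck_{fwd,bwd}_impl branch
+  // applied, and that callers of this predicate are expected to reproduce, was in sketch:
+  //   size_t rt_q  = ck_fused_attn::get_runtime_max_seqlen(b, devPtrCuSeqlensQ,  nullptr, ws, stream);
+  //   size_t rt_kv = ck_fused_attn::get_runtime_max_seqlen(b, devPtrCuSeqlensKV, nullptr, ws, stream);
+  //   if (rt_q == 1 && rt_kv >= 2 && rt_kv <= 16) { /* take the small-seq kernels */ }
+  // with ws an 8-byte scratch slot for the device-side max reduction.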
+ (void)max_seqlen_q; + (void)max_seqlen_kv; + + if (head_dim_qk != head_dim_v) { + if (log) { + std::cout << "small-seq: head_dim_qk and head_dim_v must match" << std::endl; + } + return false; + } + if (head_dim_qk != 128 && head_dim_qk != 256 && head_dim_qk != 512) { + if (log) { + std::cout << "small-seq: head_dim_qk must be 128, 256, or 512" << std::endl; + } + return false; + } + + bool is_causal = + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + if (is_causal) { + if (log) { + std::cout << "small-seq: causal mask types not supported" << std::endl; + } + return false; + } + + if (attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) { + if (log) { + std::cout << "small-seq: only NO_MASK or PADDING_MASK supported" << std::endl; + } + return false; + } + + if (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) { + if (!((window_size_left == -1 && window_size_right == -1) || + (window_size_left >= 0 && window_size_right >= 0))) { + if (log) { + std::cout << "small-seq: invalid window size for mask type" << std::endl; + } + return false; + } + } + + return true; +} + } // namespace fused_attn_rocm } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h similarity index 78% rename from transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h rename to transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h index 818b5448a..f3abd6bbd 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h @@ -4,20 +4,37 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -/*! \file fused_attn_smallseq.h +/*! \file fused_attn_small_seq.h * \brief Small-seq (varlen) attention for ROCm: seq_q=1, max_seqlen_kv<=16, THD only. */ -#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ -#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALL_SEQ_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALL_SEQ_H_ #include +#include "transformer_engine/fused_attn.h" namespace transformer_engine { namespace fused_attn_rocm { +bool is_small_seq_attn_supported( + NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float dropout, + size_t num_attn_heads, + size_t num_gqa_groups, + size_t max_seqlen_q, + size_t max_seqlen_kv, + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, + int64_t window_size_right); + /** Workspace size in bytes for small-seq backward path */ -size_t fused_attn_smallseq_bwd_workspace_size(size_t b, +size_t fused_attn_small_seq_bwd_workspace_size(size_t b, size_t h_q, size_t max_seqlen_kv, DType dtype); @@ -26,7 +43,7 @@ size_t fused_attn_smallseq_bwd_workspace_size(size_t b, * attn_weights_buffer is also used as internal workspace (scores then overwritten by attn * weights). 
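+ * The buffer holds b * h_q * 1 * max_seqlen_kv elements of the Q/K/V dtype, matching the
+ * attention-weights softmax-aux layout the JAX bridge allocates.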
No separate workspace required for the launcher; caller may use workspace for * get_runtime_max_seqlen (8 bytes). */ -void fused_attn_smallseq_fwd(size_t b, +void fused_attn_small_seq_fwd(size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_kv, @@ -50,8 +67,8 @@ void fused_attn_smallseq_fwd(size_t b, cudaStream_t stream); /** Backward: dO, O, attn_weights -> dQ, dK, dV. attn_weights is the buffer from forward - * (output_S). workspace must be at least fused_attn_smallseq_bwd_workspace_size. */ -void fused_attn_smallseq_bwd(size_t b, + * (output_S). workspace must be at least fused_attn_small_seq_bwd_workspace_size. */ +void fused_attn_small_seq_bwd(size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_kv, @@ -78,4 +95,4 @@ void fused_attn_smallseq_bwd(size_t b, } // namespace fused_attn_rocm } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ +#endif // TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALL_SEQ_H_ diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 851032e04..f23a0ec0a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -633,6 +633,43 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); +#ifdef __HIP_PLATFORM_AMD__ +bool nvte_is_small_seq_attn_supported( + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); + +/*! \brief Device workspace bytes required for small-seq backward. */ +size_t nvte_fused_attn_small_seq_bwd_workspace_size(size_t batch, size_t attn_heads, + size_t max_seqlen_kv, NVTEDType dtype); + +/*! \brief ROCm small-seq forward: separate Q, K, V; same tensor roles as nvte_fused_attn_fwd. */ +void nvte_fused_attn_small_seq_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief ROCm small-seq backward. 
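+ *
+ *  Consumes the attention weights produced by the forward pass (Aux_CTX_Tensors slot 0);
+ *  the workspace must be at least nvte_fused_attn_small_seq_bwd_workspace_size bytes.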
*/ +void nvte_fused_attn_small_seq_bwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, + const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, + NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + cudaStream_t stream); +#endif // __HIP_PLATFORM_AMD__ + /*! \brief Update the RNG state with the seed and calculated offset. * * \warning This API is **experimental** and subject to change. From f8a5ce86a0eaa3013bc8f20c38f4848d50814ca2 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 17 Apr 2026 15:43:43 +0000 Subject: [PATCH 12/21] feat(jax): C++ FFI bridge for ROCm small-seq attention Wire the explicit small-sequence path in JAX csrc. - extensions.h: declare GetSmallSeqAttn{Forward,Backward}WorkspaceSizes, SmallSeqAttn{Forward,Backward}FFI handlers, and XLA_FFI_DECLARE_HANDLER_SYMBOL exports for XLA registration. - attention.cpp (USE_ROCM): add PrepareSmallSeqAttnForwardAuxTensors / PrepareSmallSeqAttnBackwardAuxTensors to build NVTETensorPack for small-seq (softmax slot = attention-weights buffer layout, RNG slot per fused API contract); memset ragged output/softmax aux as needed for THD. - GetSmallSeqAttnForwardWorkspaceSizes / GetSmallSeqAttnBackwardWorkspaceSizes: gate on nvte_is_small_seq_attn_supported, return minimal forward workspace and nvte_fused_attn_small_seq_bwd_workspace_size-backed backward scratch. - SmallSeqAttnForwardImpl / SmallSeqAttnBackwardImpl: reuse FUSED_ATTN_IMPL_COMMON_BLOCK for THD cu_seqlens/offsets, call nvte_fused_attn_small_seq_fwd / _bwd. - SmallSeqAttnForwardFFI / SmallSeqAttnBackwardFFI + XLA_FFI_DEFINE_HANDLER_SYMBOL: mirror FusedAttn*FFI attribute unpacking so JAX can invoke the dedicated backend. 
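The aux-pack convention the bridge establishes, sketched from PrepareSmallSeqAttnForwardAuxTensors in the diff below (a minimal usage outline under the shapes shown there, not a verbatim excerpt of the patch):

    // Aux pack consumed by nvte_fused_attn_small_seq_{fwd,bwd}:
    //   tensors[0]: attention weights, shape {batch, heads, s_q, min(s_kv, 16)}, Q/K/V dtype
    //   tensors[1]: RNG state (int64 seed/offset pair)
    NVTETensorPack aux;
    nvte_tensor_pack_create(&aux);
    PrepareSmallSeqAttnForwardAuxTensors(&aux, input_batch, attn_heads, q_max_seqlen,
                                         kv_max_seqlen, dtype, softmax_buf, rng_state_buf);
    // ... nvte_fused_attn_small_seq_fwd(..., &aux, ..., stream) ...
    nvte_tensor_pack_destroy(&aux);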
--- transformer_engine/jax/csrc/extensions.h | 21 ++ .../jax/csrc/extensions/attention.cpp | 330 ++++++++++++++++-- 2 files changed, 325 insertions(+), 26 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 453a4202b..b15d8c5c8 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -123,6 +123,27 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); +#ifdef USE_ROCM +XLA_FFI_DECLARE_HANDLER_SYMBOL(SmallSeqAttnForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(SmallSeqAttnBackwardHandler); + +pybind11::tuple GetSmallSeqAttnForwardWorkspaceSizes( + size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, + size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); + +pybind11::tuple GetSmallSeqAttnBackwardWorkspaceSizes( + size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, + bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, + int64_t window_size_right); +#endif // USE_ROCM + // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index b575a56a4..dc65014e3 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -337,6 +337,186 @@ static void FusedAttnForwardImpl( nvte_tensor_pack_destroy(&aux_output_tensors); } +#ifdef USE_ROCM +void PrepareSmallSeqAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, + const size_t attn_heads, const size_t q_max_seqlen, + const size_t kv_max_seqlen, DType dtype, + void *softmax_buf, void *rng_state_buf) { + tensor_pack->size = 2; + NVTETensor &softmax_aux = tensor_pack->tensors[0]; + NVTEBasicTensor softmax_aux_data; + softmax_aux_data.data_ptr = softmax_buf; + softmax_aux_data.shape.ndim = 4; + softmax_aux_data.shape.data[0] = input_batch; + softmax_aux_data.shape.data[1] = attn_heads; + softmax_aux_data.shape.data[2] = q_max_seqlen; + softmax_aux_data.shape.data[3] = std::min(kv_max_seqlen, size_t{16}); + softmax_aux_data.dtype = static_cast(dtype); + nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data); + + NVTETensor &rng_state_aux = tensor_pack->tensors[1]; + NVTEBasicTensor rng_state_aux_data; + rng_state_aux_data.data_ptr = rng_state_buf; + rng_state_aux_data.shape = {}; + rng_state_aux_data.shape.ndim = 2; + rng_state_aux_data.dtype = static_cast(DType::kInt64); + nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data); +} + +void PrepareSmallSeqAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, + const size_t attn_heads, const size_t q_max_seqlen, + const size_t kv_max_seqlen, DType dtype, void *softmax_buf, + void *rng_state_buf) { + 
PrepareSmallSeqAttnForwardAuxTensors(tensor_pack, input_batch, attn_heads, q_max_seqlen, + kv_max_seqlen, dtype, softmax_buf, rng_state_buf); +} + +pybind11::tuple GetSmallSeqAttnForwardWorkspaceSizes( + size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, + size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { + (void)scaling_factor; + (void)is_training; + (void)max_segments_per_seq; + NVTE_CHECK( + nvte_is_small_seq_attn_supported( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right), + "GetSmallSeqAttnForwardWorkspaceSizes: configuration not supported."); + NVTE_CHECK(bias_batch == 0 && bias_heads == 0, + "GetSmallSeqAttnForwardWorkspaceSizes: bias not supported for small-seq."); + TensorWrapper query_workspace_tensor(nullptr, std::vector{1}, DType::kByte); + return pybind11::make_tuple(MakeShapeVector(query_workspace_tensor.shape()), + query_workspace_tensor.dtype()); +} + +static void SmallSeqAttnForwardImpl( + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens, + void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux, + void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, + size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { + (void)deterministic; + FUSED_ATTN_IMPL_COMMON_BLOCK; + NVTE_CHECK(layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD && is_ragged, + "SmallSeqAttnForwardImpl: requires THD separate Q, K, V."); + NVTE_CHECK(bias_batch == 0 && bias_heads == 0, + "SmallSeqAttnForwardImpl: bias not supported for small-seq."); + + auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); + + if (is_ragged) { + auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim; + (void)cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); + size_t kv_eff = std::min(kv_max_seqlen, size_t{16}); + auto softmax_aux_elems = input_batch * q_max_seqlen * attn_heads * kv_eff; + (void)cudaMemsetAsync(softmax_aux, 0, softmax_aux_elems * typeToSize(dtype), stream); + } + + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; + auto o_tensor = TensorWrapper(output, o_shape, dtype); + auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + + nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, + NVTE_Fused_Attn_Backend::NVTE_CK, stream); + + NVTETensorPack aux_output_tensors; + nvte_tensor_pack_create(&aux_output_tensors); + PrepareSmallSeqAttnForwardAuxTensors(&aux_output_tensors, input_batch, attn_heads, q_max_seqlen, + kv_max_seqlen, dtype, softmax_aux, 
rng_state); + + auto workspace_tensor = + TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); + auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); + + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v_tensor = TensorWrapper(v, v_shape, dtype); + + nvte_fused_attn_small_seq_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); + + nvte_tensor_pack_destroy(&aux_output_tensors); +} + +static void SmallSeqAttnBackwardImpl( + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state, + void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, + void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace, + size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, + float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic, int64_t window_size_left, int64_t window_size_right) { + FUSED_ATTN_IMPL_COMMON_BLOCK; + NVTE_CHECK(layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD && is_ragged, + "SmallSeqAttnBackwardImpl: requires THD separate Q, K, V."); + NVTE_CHECK(bias_batch == 0 && bias_heads == 0, + "SmallSeqAttnBackwardImpl: bias not supported for small-seq."); + + auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, v_head_dim}; + auto output_tensor = TensorWrapper(output, output_shape, dtype); + auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); + + NVTETensorPack aux_input_tensors; + nvte_tensor_pack_create(&aux_input_tensors); + PrepareSmallSeqAttnBackwardAuxTensors(&aux_input_tensors, input_batch, attn_heads, q_max_seqlen, + kv_max_seqlen, dtype, softmax_aux, rng_state); + + auto workspace_tensor = + TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); + + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v_tensor = TensorWrapper(v, v_shape, dtype); + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dk_tensor = TensorWrapper(dk, k_shape, dtype); + auto 
dv_tensor = TensorWrapper(dv, v_shape, dtype); + + if (is_ragged) { + (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), + stream); + (void)cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), + stream); + (void)cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), + stream); + } + + nvte_fused_attn_small_seq_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), s_tensor.data(), s_tensor.data(), &aux_input_tensors, + dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + deterministic, workspace_tensor.data(), stream); + + nvte_tensor_pack_destroy(&aux_input_tensors); + (void)is_training; +} +#endif // USE_ROCM + #define FUSED_ATTN_FFI_GET_ATTRS \ size_t input_batch = get_attr_value(attrs, "input_batch"); \ size_t bias_batch = get_attr_value(attrs, "bias_batch"); \ @@ -510,35 +690,37 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); - if (is_ragged && q_max_seqlen!=kv_max_seqlen && nvte_smallseq && std::string(nvte_smallseq) == "1") { - size_t workspace_elems = product(work_shape); - size_t elt_size = transformer_engine::typeToSize(query_workspace_tensor.dtype()); - size_t workspace_bytes = workspace_elems * elt_size; - size_t unfused_small_seq_workspace = input_batch * attn_heads * 16 * 2; // min for unfused small-seq (bf16/fp16) - - if (workspace_bytes < unfused_small_seq_workspace) { - size_t min_elems = (unfused_small_seq_workspace + elt_size - 1) / elt_size; - work_shape = std::vector{min_elems}; - workspace_elems = min_elems; - workspace_bytes = workspace_elems * elt_size; - } - - const char* nvte_log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG"); - if (nvte_log_ck_config && std::string(nvte_log_ck_config) == "1") { - std::cout << std::endl << "attn_bwd(ck unfused small-seq workspace size): "; - std::cout << "input_batch: " << input_batch << ", "; - std::cout << "is_ragged: " << is_ragged << ", "; - std::cout << "workspace_elems: " << workspace_elems << ", "; - std::cout << "workspace_bytes: " << workspace_bytes << ", "; - std::cout << "unfused_small_seq_min_bytes: " << unfused_small_seq_workspace << ", "; - std::cout << "workspace_bytes >= unfused_small_seq_workspace: " - << (workspace_bytes >= unfused_small_seq_workspace ? 
"true" : "false") << std::endl; - } - } return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } +#ifdef USE_ROCM +pybind11::tuple GetSmallSeqAttnBackwardWorkspaceSizes( + size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim, + size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, + bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, + int64_t window_size_right) { + (void)scaling_factor; + (void)is_training; + (void)deterministic; + (void)max_segments_per_seq; + NVTE_CHECK( + nvte_is_small_seq_attn_supported( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right), + "GetSmallSeqAttnBackwardWorkspaceSizes: configuration not supported."); + NVTE_CHECK(bias_batch == 0 && bias_heads == 0, + "GetSmallSeqAttnBackwardWorkspaceSizes: bias not supported for small-seq."); + size_t bwd_bytes = nvte_fused_attn_small_seq_bwd_workspace_size( + input_batch, attn_heads, kv_max_seqlen, static_cast(dtype)); + TensorWrapper query_workspace_tensor(nullptr, std::vector{bwd_bytes}, DType::kByte); + return pybind11::make_tuple(MakeShapeVector(query_workspace_tensor.shape()), + query_workspace_tensor.dtype()); +} +#endif // USE_ROCM + static void FusedAttnBackwardImpl( cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state, void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, @@ -692,5 +874,101 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Attrs(), FFI_CudaGraph_Traits); +#ifdef USE_ROCM +Error_Type SmallSeqAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf, + Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, + Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Variadic_Buffer_Type _unused_args, Result_Type output_buf, + Result_Type softmax_aux_buf, Result_Type rng_state_buf, + Result_Type workspace_buf, Dictionary attrs) { + FUSED_ATTN_FFI_GET_ATTRS; + + SmallSeqAttnForwardImpl( + stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), + bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), + kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? 
k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(), + softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(), + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, + qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, + dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training, + deterministic, window_size_left, window_size_right); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SmallSeqAttnForwardHandler, SmallSeqAttnForwardFFI, + FFI::Bind() + .Ctx() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .RemainingArgs() + .Ret() + .Ret() + .Ret() + .Ret() + .Attrs(), + FFI_CudaGraph_Traits); + +Error_Type SmallSeqAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf, + Buffer_Type output_buf, Buffer_Type doutput_buf, + Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, + Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Variadic_Buffer_Type _unused_args, Result_Type dq_buf, + Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf, + Result_Type workspace_buf, Dictionary attrs) { + FUSED_ATTN_FFI_GET_ATTRS; + + SmallSeqAttnBackwardImpl( + stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), + bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(), + output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), + kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? 
k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(), + dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(), + workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, + attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, + wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, + wkspace_dtype, is_training, deterministic, window_size_left, window_size_right); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SmallSeqAttnBackwardHandler, SmallSeqAttnBackwardFFI, + FFI::Bind() + .Ctx() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .RemainingArgs() + .Ret() + .Ret() + .Ret() + .Ret() + .Ret() + .Attrs(), + FFI_CudaGraph_Traits); +#endif // USE_ROCM + } // namespace jax } // namespace transformer_engine From 9dd9b7e96d0d97047893869d38f1c9eaa2e1935c Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 17 Apr 2026 18:51:26 +0000 Subject: [PATCH 13/21] feat(jax): pybind registration for ROCm small-seq attention FFI --- transformer_engine/jax/csrc/extensions/pybind.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 563675988..d12229728 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -75,6 +75,8 @@ pybind11::dict Registrations() { // Attention dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_backward_ffi"] = EncapsulateFFI(FusedAttnBackwardHandler); + dict["te_small_seq_attn_forward_ffi"] = EncapsulateFFI(SmallSeqAttnForwardHandler); + dict["te_small_seq_attn_backward_ffi"] = EncapsulateFFI(SmallSeqAttnBackwardHandler); dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); dict["te_grouped_gemm_ffi"] = EncapsulateFFI(GroupedGemmHandler); @@ -100,6 +102,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); +#ifdef USE_ROCM + m.def("get_small_seq_attn_fwd_workspace_sizes", &GetSmallSeqAttnForwardWorkspaceSizes); + m.def("get_small_seq_attn_bwd_workspace_sizes", &GetSmallSeqAttnBackwardWorkspaceSizes); +#endif m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); From 1c6ffd5ccf146f7d1fe327a01baa547a2f84baf9 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 17 Apr 2026 19:59:09 +0000 Subject: [PATCH 14/21] feat(jax): XLA primitives for ROCm small-seq attention Add SmallSeqAttnFwdPrimitive / SmallSeqAttnBwdPrimitive in cpp_extensions/attention.py so JAX compiles and lowers to the dedicated small-seq FFI without nvte_get_fused_attn_backend or generic fused-attn workspace probing. - abstract: HIP-only; validate THD_THD_THD, no bias/dropout and head dims; softmax_aux shape (*batch, heads, q, min(kv,16)) in Q dtype; workspace from get_small_seq_attn_{fwd,bwd}_workspace_sizes. - lowering: ffi_lowering to te_small_seq_attn_{forward,backward}_ffi with the same flattened attrs pattern as generic fused attention. - fused_attn_small_seq_fwd / fused_attn_small_seq_bwd: thin bind helpers; export via __all__. 
register_primitive for both primitives. --- .../jax/cpp_extensions/attention.py | 1147 ++++++++++++++--- 1 file changed, 968 insertions(+), 179 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index cd64e259d..18688f56c 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -63,6 +63,8 @@ "FusedAttnHelper", "fused_attn_fwd", "fused_attn_bwd", + "fused_attn_small_seq_fwd", + "fused_attn_small_seq_bwd", ] @@ -366,39 +368,14 @@ def abstract( softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) elif backend == NVTE_Fused_Attn_Backend.NVTE_CK: if config.qkv_layout.is_thd(): - # THD only: check env; run small-seq logic only when enabled - if os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1" or q_max_seqlen == kv_max_seqlen: - softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) - else: - batch_size = reduce(operator.mul, batch_shape) - ck_standard_softmax_aux_size = ( - batch_size * attn_heads * q_max_seqlen * 1 * 4 - ) # 4 bytes for the float32 - ck_smallseq_softmax_aux_size = ( - batch_size * attn_heads * q_max_seqlen - * min(kv_max_seqlen, 16) * 2 - ) # 2 bytes for bf16/fp16 - if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) - else: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16)) - softmax_dtype = dtypes.canonicalize_dtype(q_dtype) + softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") - if os.environ.get("NVTE_LOG_CK_CONFIG", "0") == "1": - jax.debug.print( - "attn_fwd(ck small-seq JAX abstract): batch_shape: {}, softmax_shape: {}, softmax_dtype: {}", - batch_shape, - softmax_shape, - softmax_dtype, - ) - softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with @@ -608,19 +585,819 @@ def convert_to_2d(offsets, batch, max_seqlen): fill_value = 0 else: fill_value = -1 - + + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) + + # Flatten the offset calculation + # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] + q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) + k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + + # Gather valid q_seq_offsets, which is greater and equal to 0 + # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] + # And set the unused position to max size (batch * max_seqlen) + # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) + + q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) + + output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( + q, + k, + 
v, + bias, + seed, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=config, + ) + return output, softmax_aux, rng_state + + @staticmethod + def batcher(batched_args, batch_dims, *, config): + check_valid_batch_dims(batch_dims) + assert FusedAttnFwdPrimitive.outer_primitive is not None + q_bdim, _, _, _, seed_bdim, *_ = batch_dims + + out_bdims = q_bdim, q_bdim, seed_bdim + return ( + FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), + out_bdims, + ) + + @staticmethod + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + + # when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+ + # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments) + is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd() + + if config.qkv_layout.is_qkvpacked(): + # q_spec = (...batch, q_seqlen, 3, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) + if not is_packed_softmax: + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) + ) + else: + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None) + ) + elif config.qkv_layout.is_kvpacked(): + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + if not is_packed_softmax: + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) + else: + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None) + ) + elif config.qkv_layout.is_separate(): + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + if not is_packed_softmax: + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) + else: + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None) + ) + else: + raise ValueError(f"Unsupported {config.qkv_layout=}") + + rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) + return (out_sharding, softmax_aux_sharding, rng_state_sharding) + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + impl = partial(FusedAttnFwdPrimitive.impl, config=config) + return mesh, impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + if version.parse(jax.__version__) < version.parse("0.5.0"): + raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") + del mesh, result_types + + # Keep in sync with `infer_sharding_from_operands`. 
+ # We only need the first input. Fill up the rest with placeholders. + input_spec = [(f"…{x}",) for x in range(len(value_types))] + # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint + # instead. This has to happen outside of the primitive, see `fused_attn_fwd`. + rng_sharding = (f"…{len(value_types)}",) + + if config.qkv_layout.is_qkvpacked(): + input_spec[0] = ("…0", "seqlen", "three", "head", "hidden") + elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate(): + input_spec[0] = ("…0", "seqlen", "head", "hidden") + else: + raise ValueError(f"Unsupported {config.qkv_layout=}") + + is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd() + out_sharding = ("…0", "seqlen", "head", "hidden") + if is_packed_softmax: + softmax_aux_sharding = ("…0", "seqlen", "head", "i") + else: + softmax_aux_sharding = ("…0", "head", "seqlen", "i") + + return SdyShardingRule( + tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding) + ) + + +register_primitive(FusedAttnFwdPrimitive) + + +class FusedAttnBwdPrimitive(BasePrimitive): + """ + Fused Attention Backward Primitive + """ + + name = "te_fused_attn_backward_ffi" + multiple_results = True + impl_static_args = (16,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + q_aval, + k_aval, + v_aval, + bias_aval, + softmax_aux_aval, + rng_state_aval, + output_aval, + doutput_aval, + q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, + _q_seq_offsets, + _k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + *, + config, + ): + """ + Fused attention bwd abstract + """ + del softmax_aux_aval, rng_state_aval, output_aval + + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + qk_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + + if config.attn_bias_type == AttnBiasType.NO_BIAS: + bias_batch = bias_heads = 0 + else: + *bias_batch_shape, bias_heads, _, _ = bias_aval.shape + bias_batch = reduce(operator.mul, bias_batch_shape) + + deterministic = not FusedAttnHelper.is_non_deterministic_allowed() + + input_batch = reduce(operator.mul, batch_shape) + wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + qk_head_dim, + v_head_dim, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.qkv_layout.value, + jax_dtype_to_te_dtype(q_aval.dtype), + config.is_training, + deterministic, + config.max_segments_per_seq, + config.window_size[0], + config.window_size[1], + ) + + dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) + dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) + dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + wkspace_aval = q_aval.update( + shape=wkspace_shape, 
dtype=te_dtype_to_jax_dtype(wkspace_dtype) + ) + + return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Fused attention fwd outer primitive abstract + """ + dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs) + return dq_aval, dk_aval, dv_aval, dbias_aval + + @staticmethod + def lowering( + ctx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + *, + config, + ): + """ + Fused attention bwd lowering rules + """ + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + qk_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + + input_batch = reduce(operator.mul, batch_shape) + + if config.attn_bias_type == AttnBiasType.NO_BIAS: + bias_batch = bias_heads = 0 + else: + *bias_batch_shape, bias_heads, _, _ = bias_aval.shape + bias_batch = reduce(operator.mul, bias_batch_shape) + + if config.cp_striped_window_size is not None: + window_size_left = config.cp_striped_window_size[0] + window_size_right = config.cp_striped_window_size[1] + else: + window_size_left = config.window_size[0] + window_size_right = config.window_size[1] + + return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)( + ctx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + qk_head_dim=qk_head_dim, + v_head_dim=v_head_dim, + max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + + @staticmethod + def impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config, + ): + assert FusedAttnBwdPrimitive.inner_primitive is not None + + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + + if config.qkv_layout.is_thd(): + + def _fix_len_take(x, condition, fill_value=-1): + x_shape = x.shape + x = x.flatten() + size = x.size + indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] + # TODO(rewang): try indices_are_sorted + y = jnp.take(x, indices, 
fill_value=fill_value) + return jnp.reshape(y, x_shape) + + def convert_to_2d(offsets, batch, max_seqlen): + offsets_2d = jnp.where( + offsets >= 0, + offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis], + offsets, + ) + return offsets_2d + + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1 + kv_batch = q_batch = batch[0] + + # Gather valid q_seqlen, which is greater than 0 + # cuDNN version < 9.3.0: + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] + # cuDNN version >= 9.3.0, which supports act_seqlen = 0 + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) + + # Flatten the offset calculation + # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] + q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) + k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + + # Gather valid q_seq_offsets, which is greater and equal to 0 + # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] + # And set the unused position to max size (batch * max_seqlen) + # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) + + q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) + + dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=config, + ) + return dq, dk, dv, dbias + + @staticmethod + def batcher(batched_args, batch_dims, *, config): + check_valid_batch_dims(batch_dims) + assert FusedAttnBwdPrimitive.outer_primitive is not None + q_bdim, k_bdim, v_bdim, *_ = batch_dims + + out_bdims = q_bdim, k_bdim, v_bdim, q_bdim + return ( + FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), + out_bdims, + ) + + @staticmethod + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del config, result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, 
PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + def sharded_impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=config, + ) + global_dbias = local_dbias + if config.attn_bias_type is not AttnBiasType.NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + return local_dq, local_dk, local_dv, global_dbias + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + if version.parse(jax.__version__) < version.parse("0.5.0"): + raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") + del config, mesh + # We only care about the four first arguments. + # Keep in sync with `infer_sharding_from_operands`. + input_spec = tuple((f"…{x}",) for x in range(len(value_types))) + output_spec = tuple((f"…{x}",) for x in range(len(result_types))) + return SdyShardingRule(input_spec, output_spec) + + +register_primitive(FusedAttnBwdPrimitive) + + +class SmallSeqAttnFwdPrimitive(BasePrimitive): + """ROCm small-sequence cross-attention forward""" + + name = "te_small_seq_attn_forward_ffi" + multiple_results = True + impl_static_args = (13,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + q_aval, + k_aval, + v_aval, + bias_aval, + seed_aval, + q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, + _q_seq_offsets, + _k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + *, + config: _FusedAttnConfig, + ): + if not is_hip_extension(): + raise ValueError( + "Small-seq attention requires Transformer Engine built for ROCm (HIP extension)." + ) + if not hasattr(transformer_engine_jax, "get_small_seq_attn_fwd_workspace_sizes"): + raise ValueError( + "Small-seq workspace helpers are unavailable; use a ROCm build of transformer_engine_jax." 
+ ) + if config.qkv_layout != QKVLayout.THD_THD_THD: + raise ValueError(f"Small-seq requires QKVLayout.THD_THD_THD, got {config.qkv_layout}") + if config.attn_bias_type != AttnBiasType.NO_BIAS: + raise ValueError("Small-seq does not support attention bias.") + if config.dropout_probability != 0: + raise ValueError("Small-seq kernel path does not support dropout.") + + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + q_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + + if q_head_dim != v_head_dim or q_head_dim not in (128, 256, 512): + raise ValueError( + "Small-seq requires matching q/v head dims in {128, 256, 512}, " + f"got qk={q_head_dim} v={v_head_dim}" + ) + + output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim) + out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) + + kv_eff = min(kv_max_seqlen, 16) + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_eff) + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=q_dtype) + + checker = _FusedAttnRNGStateChecker() + seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) + assert seed_dtype == checker.rng_state_dtype + rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) + rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) + + bias_batch = bias_heads = 0 + input_batch = reduce(operator.mul, batch_shape) + wkspace_info = transformer_engine_jax.get_small_seq_attn_fwd_workspace_sizes( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + q_head_dim, + v_head_dim, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.qkv_layout.value, + jax_dtype_to_te_dtype(q_aval.dtype), + config.is_training, + config.max_segments_per_seq, + config.window_size[0], + config.window_size[1], + ) + wkspace_aval = q_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + out_aval, softmax_aux_aval, rng_state_aval, _ = SmallSeqAttnFwdPrimitive.abstract( + *args, **kwargs + ) + return out_aval, softmax_aux_aval, rng_state_aval + + @staticmethod + def lowering( + ctx, + q, + k, + v, + bias, + seed, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + *, + config: _FusedAttnConfig, + ): + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + q_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + input_batch = reduce(operator.mul, batch_shape) + bias_batch = bias_heads = 0 + window_size_left = config.window_size[0] + window_size_right = config.window_size[1] + + return ffi.ffi_lowering(SmallSeqAttnFwdPrimitive.name)( + ctx, + q, + k, + v, + bias, + seed, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + 
_kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + qk_head_dim=q_head_dim, + v_head_dim=v_head_dim, + max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + + @staticmethod + def impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config: _FusedAttnConfig, + ): + assert SmallSeqAttnFwdPrimitive.inner_primitive is not None + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + if config.qkv_layout.is_thd(): + + def _fix_len_take(x, condition, fill_value=-1): + x_shape = x.shape + x = x.flatten() + size = x.size + indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] + y = jnp.take(x, indices, fill_value=fill_value) + return jnp.reshape(y, x_shape) + + def convert_to_2d(offsets, batch, max_seqlen): + offsets_2d = jnp.where( + offsets >= 0, + offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis], + offsets, + ) + return offsets_2d + + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1 + kv_batch = q_batch = batch[0] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) - - # Flatten the offset calculation - # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) - - # Gather valid q_seq_offsets, which is greater and equal to 0 - # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - # And set the unused position to max size (batch * max_seqlen) - # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] q_seq_offsets = _fix_len_take( q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen ) @@ -631,7 +1408,7 @@ def convert_to_2d(offsets, batch, max_seqlen): q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) - output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( + output, softmax_aux, rng_state, _ = SmallSeqAttnFwdPrimitive.inner_primitive.bind( q, k, v, @@ -652,62 +1429,22 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - 
assert FusedAttnFwdPrimitive.outer_primitive is not None + assert SmallSeqAttnFwdPrimitive.outer_primitive is not None q_bdim, _, _, _, seed_bdim, *_ = batch_dims - out_bdims = q_bdim, q_bdim, seed_bdim return ( - FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), + SmallSeqAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @staticmethod def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): - del result_infos + del config, result_infos q_spec = get_padded_spec(arg_infos[0]) - - # when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+ - # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments) - is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd() - - if config.qkv_layout.is_qkvpacked(): - # q_spec = (...batch, q_seqlen, 3, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) - if not is_packed_softmax: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) - ) - else: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None) - ) - elif config.qkv_layout.is_kvpacked(): - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - if not is_packed_softmax: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - else: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None) - ) - elif config.qkv_layout.is_separate(): - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - if not is_packed_softmax: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - else: - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None) - ) - else: - raise ValueError(f"Unsupported {config.qkv_layout=}") - + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @@ -724,7 +1461,7 @@ def partition(config, mesh, arg_infos, result_infos): arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(FusedAttnFwdPrimitive.impl, config=config) + impl = partial(SmallSeqAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings @staticmethod @@ -732,42 +1469,20 @@ def shardy_sharding_rule(config, mesh, value_types, result_types): if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del mesh, result_types - - # Keep in sync with `infer_sharding_from_operands`. - # We only need the first input. Fill up the rest with placeholders. input_spec = [(f"…{x}",) for x in range(len(value_types))] - # The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint - # instead. 
This has to happen outside of the primitive, see `fused_attn_fwd`. rng_sharding = (f"…{len(value_types)}",) - - if config.qkv_layout.is_qkvpacked(): - input_spec[0] = ("…0", "seqlen", "three", "head", "hidden") - elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate(): - input_spec[0] = ("…0", "seqlen", "head", "hidden") - else: - raise ValueError(f"Unsupported {config.qkv_layout=}") - - is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd() + input_spec[0] = ("…0", "seqlen", "head", "hidden") out_sharding = ("…0", "seqlen", "head", "hidden") - if is_packed_softmax: - softmax_aux_sharding = ("…0", "seqlen", "head", "i") - else: - softmax_aux_sharding = ("…0", "head", "seqlen", "i") - + softmax_aux_sharding = ("…0", "head", "seqlen", "i") return SdyShardingRule( tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding) ) -register_primitive(FusedAttnFwdPrimitive) - - -class FusedAttnBwdPrimitive(BasePrimitive): - """ - Fused Attention Backward Primitive - """ +class SmallSeqAttnBwdPrimitive(BasePrimitive): + """ROCm small-sequence cross-attention backward.""" - name = "te_fused_attn_backward_ffi" + name = "te_small_seq_attn_backward_ffi" multiple_results = True impl_static_args = (16,) inner_primitive = None @@ -794,9 +1509,14 @@ def abstract( *, config, ): - """ - Fused attention bwd abstract - """ + if not is_hip_extension(): + raise ValueError( + "Small-seq attention requires Transformer Engine built for ROCm (HIP extension)." + ) + if not hasattr(transformer_engine_jax, "get_small_seq_attn_bwd_workspace_sizes"): + raise ValueError( + "Small-seq workspace helpers are unavailable; use a ROCm build of transformer_engine_jax." + ) del softmax_aux_aval, rng_state_aval, output_aval q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) @@ -817,16 +1537,10 @@ def abstract( v_head_dim, ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - if config.attn_bias_type == AttnBiasType.NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - + bias_batch = bias_heads = 0 deterministic = not FusedAttnHelper.is_non_deterministic_allowed() - input_batch = reduce(operator.mul, batch_shape) - wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( + wkspace_shape, wkspace_dtype = transformer_engine_jax.get_small_seq_attn_bwd_workspace_sizes( input_batch, bias_batch, q_max_seqlen, @@ -856,15 +1570,13 @@ def abstract( wkspace_aval = q_aval.update( shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype) ) - return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): - """ - Fused attention fwd outer primitive abstract - """ - dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs) + dq_aval, dk_aval, dv_aval, dbias_aval, _ = SmallSeqAttnBwdPrimitive.abstract( + *args, **kwargs + ) return dq_aval, dk_aval, dv_aval, dbias_aval @staticmethod @@ -889,11 +1601,7 @@ def lowering( *, config, ): - """ - Fused attention bwd lowering rules - """ q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - ( batch_shape, q_max_seqlen, @@ -903,23 +1611,12 @@ def lowering( qk_head_dim, v_head_dim, ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - input_batch = reduce(operator.mul, batch_shape) + bias_batch = bias_heads = 0 + window_size_left = config.window_size[0] + window_size_right = 
config.window_size[1] - if config.attn_bias_type == AttnBiasType.NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - if config.cp_striped_window_size is not None: - window_size_left = config.cp_striped_window_size[0] - window_size_right = config.cp_striped_window_size[1] - else: - window_size_left = config.window_size[0] - window_size_right = config.window_size[1] - - return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)( + return ffi.ffi_lowering(SmallSeqAttnBwdPrimitive.name)( ctx, q, k, @@ -936,7 +1633,7 @@ def lowering( q_segment_ids, kv_segment_ids, q_segment_pos, - kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering + kv_segment_pos, input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -978,15 +1675,13 @@ def impl( _kv_segment_pos, config, ): - assert FusedAttnBwdPrimitive.inner_primitive is not None - + assert SmallSeqAttnBwdPrimitive.inner_primitive is not None sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), seq_offsets=(q_seq_offsets, k_seq_offsets), segment_ids=(_q_segment_ids, _kv_segment_ids), segment_pos=(_q_segment_pos, _kv_segment_pos), ) - (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( sequence_descriptor.get_seqlens_and_offsets( config.attn_mask_type, @@ -995,7 +1690,6 @@ def impl( config.max_segments_per_seq, ) ) - if config.qkv_layout.is_thd(): def _fix_len_take(x, condition, fill_value=-1): @@ -1003,7 +1697,6 @@ def _fix_len_take(x, condition, fill_value=-1): x = x.flatten() size = x.size indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] - # TODO(rewang): try indices_are_sorted y = jnp.take(x, indices, fill_value=fill_value) return jnp.reshape(y, x_shape) @@ -1020,28 +1713,14 @@ def convert_to_2d(offsets, batch, max_seqlen): ) assert len(batch) == 1 kv_batch = q_batch = batch[0] - - # Gather valid q_seqlen, which is greater than 0 - # cuDNN version < 9.3.0: - # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] - # cuDNN version >= 9.3.0, which supports act_seqlen = 0 - # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] if get_cudnn_version() >= (9, 3, 0): fill_value = 0 else: fill_value = -1 q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) - - # Flatten the offset calculation - # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) - - # Gather valid q_seq_offsets, which is greater and equal to 0 - # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - # And set the unused position to max size (batch * max_seqlen) - # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] q_seq_offsets = _fix_len_take( q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen ) @@ -1052,7 +1731,7 @@ def convert_to_2d(offsets, batch, max_seqlen): q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) - dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( + dq, dk, dv, dbias, _ = SmallSeqAttnBwdPrimitive.inner_primitive.bind( q, k, v, @@ -1076,12 +1755,11 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def 
batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None + assert SmallSeqAttnBwdPrimitive.outer_primitive is not None q_bdim, k_bdim, v_bdim, *_ = batch_dims - out_bdims = q_bdim, k_bdim, v_bdim, q_bdim return ( - FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), + SmallSeqAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @@ -1133,7 +1811,7 @@ def sharded_impl( _q_segment_pos, _kv_segment_pos, ): - local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( + return SmallSeqAttnBwdPrimitive.impl( q, k, v, @@ -1152,10 +1830,6 @@ def sharded_impl( _kv_segment_pos, config=config, ) - global_dbias = local_dbias - if config.attn_bias_type is not AttnBiasType.NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) - return local_dq, local_dk, local_dv, global_dbias return mesh, sharded_impl, out_shardings, arg_shardings @@ -1164,14 +1838,13 @@ def shardy_sharding_rule(config, mesh, value_types, result_types): if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del config, mesh - # We only care about the four first arguments. - # Keep in sync with `infer_sharding_from_operands`. input_spec = tuple((f"…{x}",) for x in range(len(value_types))) output_spec = tuple((f"…{x}",) for x in range(len(result_types))) return SdyShardingRule(input_spec, output_spec) -register_primitive(FusedAttnBwdPrimitive) +register_primitive(SmallSeqAttnFwdPrimitive) +register_primitive(SmallSeqAttnBwdPrimitive) def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): @@ -2841,3 +3514,119 @@ def fused_attn_bwd( config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad + + +def fused_attn_small_seq_fwd( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]] = None, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Forward pass for ROCm small-sequence cross-attention (explicit TE API). + """ + if not is_hip_extension(): + raise RuntimeError( + "fused_attn_small_seq_fwd requires Transformer Engine built for ROCm (HIP)." 
+ ) + if qkv_layout != QKVLayout.THD_THD_THD: + raise ValueError( + f"fused_attn_small_seq_fwd requires QKVLayout.THD_THD_THD, got {qkv_layout}" + ) + if len(qkv) != 3: + raise ValueError(f"fused_attn_small_seq_fwd expects qkv=(q, k, v), got len={len(qkv)}") + + seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training) + if attn_bias_type == AttnBiasType.NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=qkv[0].dtype) + else: + raise ValueError("fused_attn_small_seq_fwd only supports AttnBiasType.NO_BIAS") + + fused_config = _FusedAttnConfig( + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=(-1, -1) if window_size is None else window_size, + context_parallel_load_balanced=False, + cp_axis="", + cp_striped_window_size=None, + ) + + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) + output, softmax_aux, rng_state = SmallSeqAttnFwdPrimitive.outer_primitive.bind( + *qkv, + bias, + seed, + *seq_desc_flatten, + config=fused_config, + ) + rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None)) + return (output, softmax_aux, rng_state) + + +def fused_attn_small_seq_bwd( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + softmax_aux: jnp.ndarray, + rng_state: jnp.ndarray, + output: jnp.ndarray, + doutput: jnp.ndarray, + sequence_descriptor: SequenceDescriptor, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]] = None, +): + if not is_hip_extension(): + raise RuntimeError( + "fused_attn_small_seq_bwd requires Transformer Engine built for ROCm (HIP)." + ) + if len(qkv) != 3: + raise ValueError(f"fused_attn_small_seq_bwd expects qkv=(q, k, v), got len={len(qkv)}") + if attn_bias_type == AttnBiasType.NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=qkv[0].dtype) + + fused_config = _FusedAttnConfig( + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=(-1, -1) if window_size is None else window_size, + context_parallel_load_balanced=False, + cp_axis="", + cp_striped_window_size=None, + ) + + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) + dq, dk, dv, _dbias = SmallSeqAttnBwdPrimitive.outer_primitive.bind( + *qkv, + bias, + softmax_aux, + rng_state, + output, + doutput, + *seq_desc_flatten, + config=fused_config, + ) + return (dq, dk, dv), None From 0198a0e97a41899085257c7604e4318191c90706 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 21 Apr 2026 16:22:18 +0000 Subject: [PATCH 15/21] feat(jax): public fused_attn_small_seq API with custom_vjp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expose ROCm small-sequence cross-attention at the JAX layer next to fused_attn. - Custom primitive _fused_attn_small_seq with forward/backward rules calling cpp_extensions fused_attn_small_seq_fwd/bwd (tex.*). 
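A minimal, self-contained sketch of the jax.custom_vjp contract those rules
follow (a toy function, not the TE API): every index listed in nondiff_argnums
becomes a leading positional parameter of the bwd rule, and the bwd rule must
return exactly one gradient per remaining differentiable argument. The
small-seq rules in the diff below follow the same pattern.

    import jax
    import jax.numpy as jnp
    from functools import partial

    # scale (index 2) is non-differentiable, mirroring the enum/float config
    # arguments of _fused_attn_small_seq.
    @partial(jax.custom_vjp, nondiff_argnums=(2,))
    def scaled_dot(x, y, scale):
        return scale * jnp.dot(x, y)

    def scaled_dot_fwd(x, y, scale):
        return scaled_dot(x, y, scale), (x, y)   # residuals saved for bwd

    def scaled_dot_bwd(scale, res, g):           # nondiff args come first
        x, y = res
        return (scale * g * y, scale * g * x)    # one grad per diff arg

    scaled_dot.defvjp(scaled_dot_fwd, scaled_dot_bwd)

    gx, gy = jax.grad(scaled_dot, argnums=(0, 1))(jnp.ones(3), jnp.arange(3.0), 2.0)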
- fused_attn_small_seq(): user entry point taking (q,k,v), bias slot,
  SequenceDescriptor, seed, mask/layout/scaling/dropout/is_training — targets
  the explicit small-seq backend.
---
 transformer_engine/jax/attention.py | 167 ++++++++++++++++++++++
 1 file changed, 167 insertions(+)

diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 093146162..7d3beb182 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -1001,6 +1001,130 @@ def _fused_attn_bwd_rule(
 _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
 
 
+@partial(
+    jax.custom_vjp,
+    nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12),
+)
+def _fused_attn_small_seq(
+    qkv: Tuple[jnp.ndarray, ...],
+    bias: Optional[jnp.ndarray],
+    sequence_descriptor: SequenceDescriptor,
+    seed: Optional[jnp.ndarray],
+    attn_bias_type: AttnBiasType,
+    attn_mask_type: AttnMaskType,
+    qkv_layout: QKVLayout,
+    scaling_factor: float,
+    dropout_probability: float,
+    is_training: bool,
+    max_segments_per_seq: int,
+    window_size: Optional[Tuple[int, int]],
+    context_checkpoint_name: str = "context",
+):
+    output, _ = _fused_attn_small_seq_fwd_rule(
+        qkv,
+        bias,
+        sequence_descriptor,
+        seed,
+        attn_bias_type,
+        attn_mask_type,
+        qkv_layout,
+        scaling_factor,
+        dropout_probability,
+        is_training,
+        max_segments_per_seq,
+        window_size,
+        context_checkpoint_name=context_checkpoint_name,
+    )
+    return output
+
+
+def _fused_attn_small_seq_fwd_rule(
+    qkv,
+    bias,
+    sequence_descriptor,
+    seed,
+    attn_bias_type,
+    attn_mask_type,
+    qkv_layout,
+    scaling_factor,
+    dropout_probability,
+    is_training,
+    max_segments_per_seq,
+    window_size,
+    context_checkpoint_name,
+):
+    output, softmax_aux, rng_state = tex.fused_attn_small_seq_fwd(
+        qkv,
+        bias,
+        sequence_descriptor,
+        seed,
+        attn_bias_type=attn_bias_type,
+        attn_mask_type=attn_mask_type,
+        qkv_layout=qkv_layout,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=max_segments_per_seq,
+        window_size=window_size,
+    )
+    output = checkpoint_name(output, context_checkpoint_name)
+    softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
+    rng_state = checkpoint_name(rng_state, context_checkpoint_name)
+    return output, (
+        qkv,
+        bias,
+        sequence_descriptor,
+        softmax_aux,
+        rng_state,
+        output,
+    )
+
+
+def _fused_attn_small_seq_bwd_rule(
+    attn_bias_type,
+    attn_mask_type,
+    qkv_layout,
+    scaling_factor,
+    dropout_probability,
+    is_training,
+    max_segments_per_seq,
+    window_size,
+    context_checkpoint_name,
+    ctx,
+    dz,
+):
+    del context_checkpoint_name
+    (qkv, bias, sequence_descriptor, softmax_aux, rng_state, output) = ctx
+    grad_qkv, grad_bias = tex.fused_attn_small_seq_bwd(
+        qkv,
+        bias,
+        softmax_aux,
+        rng_state,
+        output,
+        dz,
+        sequence_descriptor,
+        attn_bias_type=attn_bias_type,
+        attn_mask_type=attn_mask_type,
+        qkv_layout=qkv_layout,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=max_segments_per_seq,
+        window_size=window_size,
+    )
+    if attn_bias_type == AttnBiasType.NO_BIAS:
+        grad_bias = None
+    return (
+        grad_qkv,
+        grad_bias,
+        None,
+        None,
+    )
+
+
+_fused_attn_small_seq.defvjp(_fused_attn_small_seq_fwd_rule, _fused_attn_small_seq_bwd_rule)
+
+
 def fused_attn(
     qkv: Tuple[jnp.ndarray, ...],
     bias: Optional[jnp.ndarray],
@@ -1127,3 +1251,46 @@ def fused_attn(
     )
 
     return output
+
+
+def fused_attn_small_seq(
+    qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
+    bias: Optional[jnp.ndarray],
+    sequence_descriptor: SequenceDescriptor,
+    seed: Optional[jnp.ndarray],
+    attn_bias_type: AttnBiasType,
+    attn_mask_type: AttnMaskType,
+    qkv_layout: QKVLayout,
+    scaling_factor: float,
+    dropout_probability: float,
+    is_training: bool,
+    max_segments_per_seq: int = 1,
+    window_size: Optional[Tuple[int, int]] = None,
+    context_checkpoint_name: str = "context",
+):
+    """
+    ROCm small-sequence varlen cross-attention (explicit backend API).
+
+    This entry point calls the dedicated small-seq kernels directly, bypassing
+    generic fused-attn backend selection. Intended for THD layouts with
+    compile-time head dimensions 128 / 256 / 512 (``d_qk == d_v``), FP16/BF16,
+    and ``AttnBiasType.NO_BIAS``. Context parallelism is not supported.
+    """
+    if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray):
+        raise ValueError("fused_attn_small_seq requires a SequenceDescriptor.")
+
+    return _fused_attn_small_seq(
+        qkv,
+        bias,
+        sequence_descriptor,
+        seed,
+        attn_bias_type,
+        attn_mask_type,
+        qkv_layout,
+        scaling_factor,
+        dropout_probability,
+        is_training,
+        max_segments_per_seq,
+        window_size,
+        context_checkpoint_name=context_checkpoint_name,
+    )

From 317f15202bfebf26110ca5e0292361d4ec85a63b Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Tue, 21 Apr 2026 16:59:27 +0000
Subject: [PATCH 16/21] test(jax): CI wiring and tests for the explicit
 small-seq API

---
 ci/jax.sh                    |   6 +-
 tests/jax/test_fused_attn.py | 243 ++++++++++++++++++++++++++++++-----
 2 files changed, 217 insertions(+), 32 deletions(-)

diff --git a/ci/jax.sh b/ci/jax.sh
index f048492ba..124215211 100755
--- a/ci/jax.sh
+++ b/ci/jax.sh
@@ -57,9 +57,9 @@ run_test_config() {
     export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests
     run_default_fa 1 test_custom_call_compute.py
     run_default_fa 1 test_functions.py
-    run 1 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # skip smallseq in normal flow
-    XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled
-    NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # Using FAv2 for forward and backward pass
+    run 1 test_fused_attn.py -k 'not test_fused_attn_small_seq_explicit_api' # skip smallseq in normal flow
+    XLA_FLAGS='--xla_gpu_enable_command_buffer=' run 1 test_fused_attn.py -k 'test_fused_attn_small_seq_explicit_api' # explicit small-seq API; empty command buffer (ROCm 7.2+)
+    NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_fused_attn_small_seq_explicit_api' # Using FAv2 for forward and backward pass
     run_default_fa 1 test_helper.py
     run_default_fa 1 test_layer.py #it effectevly always uses unfused attention
     run_default_fa 1 test_sanity_import.py
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 8e2684a1b..d0b8e29b9 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -26,7 +26,7 @@
 from jax.typing import ArrayLike, DTypeLike
 
 from transformer_engine.jax import fp8_autocast
-from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
+from transformer_engine.jax.cpp_extensions.misc import get_xla_flag, is_hip_extension
 from transformer_engine.jax.sharding import MeshResource
 from transformer_engine.jax.attention import (
     AttnBiasType,
@@ -36,6 +36,7 @@
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
     fused_attn,
+    
fused_attn_small_seq, make_swa_mask, SequenceDescriptor, CPStrategy, @@ -272,6 +273,35 @@ def customcall_fused_dpa( ) +def customcall_small_seq_dpa( + query, + key, + value, + bias, + sequence_descriptor, + dropout_rng, + **kwargs, +): + """TE ROCm small-seq explicit API (separate Q, K, V only).""" + small_kwargs = { + "attn_bias_type": kwargs["attn_bias_type"], + "attn_mask_type": kwargs["attn_mask_type"], + "scaling_factor": kwargs["scaling_factor"], + "dropout_probability": kwargs["dropout_probability"], + "is_training": kwargs["is_training"], + "qkv_layout": kwargs["qkv_layout"], + "max_segments_per_seq": kwargs["max_segments_per_seq"], + "window_size": kwargs.get("window_size"), + } + return fused_attn_small_seq( + (query, key, value), + bias, + sequence_descriptor, + dropout_rng, + **small_kwargs, + ).astype(query.dtype) + + class BiasShape(Enum): """ Enum class to represent the different bias shapes used in the fused attention. @@ -323,6 +353,9 @@ class FusedAttnRunner: cp_strategy: CPStrategy = CPStrategy.DEFAULT cp_load_balanced: bool = True + # THD segment layout for ROCm small-seq explicit API tests + use_small_seq_thd_setup: bool = False + # dictionary of expected collective comm bytes coll_count_ref: Optional[Dict[str, int]] = None @@ -330,11 +363,11 @@ class FusedAttnRunner: # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): if self.qkv_layout.is_thd(): - if ( - 90400 <= get_cudnn_version() < 90500 - or ( is_hip_extension() and - os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1") - ): + # Small-seq explicit API tests use fixed segment counts; skip +1 slack used + # by generic fused THD to probe runtime_segments < max_segments on cuDNN. + if self.use_small_seq_thd_setup: + return self.num_segments_per_seq + if 90400 <= get_cudnn_version() < 90500: return self.num_segments_per_seq else: # +1 for testing runtime_segments < max_segments @@ -423,9 +456,9 @@ def _check_configs(self): "the F16_arbitrary_seqlen backend." ) - def _setup_thd_segments_ck_smallseq(self, generate_random_segment_ids): + def _setup_thd_segments_small_seq(self, generate_random_segment_ids): """ - Build THD segment descriptors for the CK small-seq path (NVTE_FUSED_ATTN_CK_SMALLSEQ=1). + Build THD segment descriptors for ROCm small-seq cross-attention tests. Uses num_segments_per_seq = max_seqlen_q for both Q and KV. For Q: if max_seqlen_q == 1, uses a fixed layout (one token per batch, cu_seqlens [0,1,...,batch_size]); otherwise @@ -449,7 +482,25 @@ def _setup_thd_segments_ck_smallseq(self, generate_random_segment_ids): segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42 ) - seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q) + # Compute seqlens/offsets directly instead of using get_seqlens_and_offsets. + # get_seqlens_and_offsets uses bincount(length=max_seqlen) which cannot capture + # segment IDs equal to max_seqlen (when num_segments == max_seqlen_q, segment + # IDs range from 1 to max_seqlen_q). The missing segment plus the appended + # sentinel causes _fix_len_take in impl() to leak entries across batches. + # Since each Q segment has exactly 1 token (max_segment_size = max_seqlen_q // + # num_segments_per_seq = 1), we build seqlens as all-ones with no sentinels. 
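A quick standalone illustration of the pitfall described in the comment above
(JAX's `bincount` drops out-of-range values, so a segment ID equal to `length`
is never counted):

    import jax.numpy as jnp

    seg_ids = jnp.array([1, 2, 3, 4])   # num_segments == max_seqlen == 4
    jnp.bincount(seg_ids, length=4)     # [0 1 1 1]  -- segment 4 silently dropped
    jnp.bincount(seg_ids, length=5)     # [0 1 1 1 1] -- all four segments counted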
+ seqlens_q = jnp.ones((self.batch_size, num_segments_per_seq), dtype=jnp.int32) + offsets_q = jnp.concatenate( + [ + jnp.tile( + jnp.arange(num_segments_per_seq, dtype=jnp.int32)[None, :], + (self.batch_size, 1), + ), + jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), + ], + axis=1, + ) + min_segment_len = None if self.window_size is None else seqlens_q segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids( @@ -595,7 +646,7 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": + if self.use_small_seq_thd_setup: ( self.num_segments_per_seq, self.segment_ids_q, @@ -608,7 +659,7 @@ def generate_random_segment_ids( self.pad_kv, self.seqlens_kv, self.offsets_kv, - ) = self._setup_thd_segments_ck_smallseq(generate_random_segment_ids) + ) = self._setup_thd_segments_small_seq(generate_random_segment_ids) else: self.num_segments_per_seq = 2 self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( @@ -1031,6 +1082,134 @@ def check_dqkv(primitive, reference, pad, idx): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) + def test_backward_small_seq_api(self): + """Backward test using fused_attn_small_seq (explicit ROCm API), not generic fused_attn.""" + self._setup_inputs() + + def grad_func(func, *args, cp_reverse_out=False, **kwargs): + gradient_multiplier = self.max_seqlen_q * self.num_heads_q + if self.attn_mask_type.is_causal(): + gradient_multiplier /= 10 + if not cp_reverse_out: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + func(*args, **kwargs), + ) + else: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + self.cp_inverse_reorder_fn(func(*args, **kwargs)), + ) + return ( + jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier + ).astype(self.dtype) + + args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] + customcall_args = [ + jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), + ] + kwargs_small = { + "attn_bias_type": self.attn_bias_type, + "attn_mask_type": self.attn_mask_type, + "scaling_factor": self.scaling_factor, + "dropout_probability": self.dropout_prob, + "is_training": self.is_training, + "qkv_layout": self.qkv_layout, + "max_segments_per_seq": self._get_max_segments_per_sequence(), + "window_size": self.window_size, + } + + if self.bias_shape == BiasShape._1HSS: + arg_nums = (0, 1, 2, 3) + grad_shardings = ( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + ) + else: + arg_nums = (0, 1, 2) + grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding) + + jitted_primitive = jit( + value_and_grad( + lambda q, k, v, bias, *args: grad_func( + customcall_small_seq_dpa, + q, + k, + v, + bias, + *args, + cp_reverse_out=True, + **kwargs_small, + ), + arg_nums, + ), + in_shardings=( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.seq_desc_sharding, + self.dropout_rng_sharding, + ), + 
out_shardings=(None, grad_shardings), + ) + kwargs_ref = { + "attn_bias_type": self.attn_bias_type, + "attn_mask_type": self.attn_mask_type, + "scaling_factor": self.scaling_factor, + "dropout_probability": self.dropout_prob, + "is_training": self.is_training, + "qkv_layout": self.qkv_layout, + "max_segments_per_seq": self._get_max_segments_per_sequence(), + "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, + } + jitted_reference = jit( + value_and_grad( + lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs_ref), + arg_nums, + ) + ) + + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + + reference_out, reference_dgrad = jitted_reference(*args) + + if self.dropout_prob > 0.0: + return + + assert_allclose(primitive_out, reference_out, dtype=self.dtype) + + def check_dqkv(primitive, reference, pad, idx): + primitive_valid, primitive_invalid, reference_valid, reference_invalid = ( + _split_valid_and_invalid(primitive, reference, pad) + ) + assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) + assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype) + assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) + + primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3] + reference_dq, reference_dk, reference_dv = reference_dgrad[:3] + + primitive_dq = self.cp_inverse_reorder_fn(primitive_dq) + primitive_dk = self.cp_inverse_reorder_fn(primitive_dk) + primitive_dv = self.cp_inverse_reorder_fn(primitive_dv) + + check_dqkv(primitive_dq, reference_dq, self.pad_q, 0) + check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1) + check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2) + @pytest.mark.parametrize( "attn_mask_type", @@ -1287,13 +1466,13 @@ def test_jax_new_rng(): runner.test_forward() -# ROCm CK small-seq varlen tests. +# ROCm small-seq varlen tests (explicit fused_attn_small_seq API). 
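For orientation before the tests, a condensed sketch of the call pattern the
explicit API expects. The sizes are hypothetical, and it assumes that
`SequenceDescriptor.from_seqlens` accepts per-batch segment-length arrays (as
with SeqDescFormat.Seqlens in the runner) and that `seed=None` is accepted when
dropout is 0:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.attention import (
        AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor, fused_attn_small_seq,
    )

    b, s_kv, h, d = 4000, 8, 16, 128                  # s_q = 1, 2 <= s_kv <= 16
    q = jnp.zeros((b, 1, h, d), jnp.bfloat16)         # one Q token per sequence
    k = jnp.zeros((b, s_kv, h, d), jnp.bfloat16)
    v = jnp.zeros((b, s_kv, h, d), jnp.bfloat16)
    seq_desc = SequenceDescriptor.from_seqlens(
        (jnp.ones((b, 1), jnp.int32), jnp.full((b, 1), s_kv, jnp.int32))
    )

    def loss(q, k, v):
        out = fused_attn_small_seq(
            (q, k, v), None, seq_desc, None,          # no bias; no RNG needed
            attn_bias_type=AttnBiasType.NO_BIAS,
            attn_mask_type=AttnMaskType.PADDING_MASK,
            qkv_layout=QKVLayout.THD_THD_THD,
            scaling_factor=d**-0.5,
            dropout_probability=0.0,
            is_training=True,
        )
        return jnp.sum(out.astype(jnp.float32))

    dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)  # drives the custom VJP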
@pytest.fixture -def ck_smallseq_env(monkeypatch): - """Enable CK small-seq path and disable XLA GPU graphs for these tests.""" - if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""): - pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'") - monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") +def xla_gpu_graph_disabled(): + if get_xla_flag("--xla_gpu_enable_command_buffer", default=None) != "": + pytest.skip( + "Test must set XLA_FLAGS with --xla_gpu_enable_command_buffer= " + ) yield @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) @@ -1306,20 +1485,27 @@ def ck_smallseq_env(monkeypatch): pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"), pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"), pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"), - pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"), - pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"), + pytest.param(4000, 1, 2, 16, 16, 256, 256, id="4000-1-2-16-16-256-256"), + pytest.param(4000, 1, 4, 16, 16, 256, 256, id="4000-1-4-16-16-256-256"), + pytest.param(4000, 1, 6, 16, 16, 256, 256, id="4000-1-6-16-16-256-256"), + pytest.param(4000, 1, 8, 16, 16, 256, 256, id="4000-1-8-16-16-256-256"), + pytest.param(4000, 1, 12, 16, 16, 256, 256, id="4000-1-12-16-16-256-256"), + pytest.param(4000, 1, 16, 16, 16, 256, 256, id="4000-1-16-16-16-256-256"), + pytest.param(4000, 1, 2, 16, 16, 512, 512, id="4000-1-2-16-16-512-512"), + pytest.param(4000, 1, 4, 16, 16, 512, 512, id="4000-1-4-16-16-512-512"), + # pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"), + # pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"), ], ) @pytest.mark.skipif( - not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware" + not is_hip_extension(), reason="Small-seq explicit API only available on AMD hardware" ) -def test_ck_unfused_smallseq_backend( - ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype +def test_fused_attn_small_seq_explicit_api( + xla_gpu_graph_disabled, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype ): """ - Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout. - Uses THD_THD_THD (Q,K,V all THD). ck_smallseq_env sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1 and - restores it after the test. + Test nvte_fused_attn_small_seq / fused_attn_small_seq: THD_THD_THD varlen cross-attention + (head dims 128, 256, or 512). 
""" runner = FusedAttnRunner( batch_size=b, @@ -1336,10 +1522,9 @@ def test_ck_unfused_smallseq_backend( dtype=dtype, is_training=True, qkv_layout=QKVLayout.THD_THD_THD, - bias_shape=None, + bias_shape=BiasShape._B1SS, window_size=None, seq_desc_format=SeqDescFormat.Seqlens, + use_small_seq_thd_setup=True, ) - runner._setup_inputs() - # runner.test_forward() - runner.test_backward() + runner.test_backward_small_seq_api() From 1104ee227603638d227c01928debc70cfc980394 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Thu, 23 Apr 2026 17:45:25 +0000 Subject: [PATCH 17/21] Fixed build issues --- transformer_engine/common/fused_attn_rocm/fused_attn.cpp | 4 ++-- transformer_engine/jax/csrc/extensions/attention.cpp | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index dd48c03a7..2be2e0668 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -959,7 +959,7 @@ void nvte_fused_attn_small_seq_fwd( for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { workspace_bytes *= wkspace->data.shape[i]; } - workspace_bytes *= nvte_dtype_size(wkspace->data.dtype); + workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype); fused_attn_rocm::fused_attn_small_seq_fwd( b, h_q, h_kv, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, @@ -1040,7 +1040,7 @@ void nvte_fused_attn_small_seq_bwd( for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { workspace_bytes *= wkspace->data.shape[i]; } - workspace_bytes *= nvte_dtype_size(wkspace->data.dtype); + workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype); NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small."); fused_attn_rocm::fused_attn_small_seq_bwd( diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index dc65014e3..605513362 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -432,8 +432,6 @@ static void SmallSeqAttnForwardImpl( PrepareSmallSeqAttnForwardAuxTensors(&aux_output_tensors, input_batch, attn_heads, q_max_seqlen, kv_max_seqlen, dtype, softmax_aux, rng_state); - auto workspace_tensor = - TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; @@ -481,9 +479,6 @@ static void SmallSeqAttnBackwardImpl( PrepareSmallSeqAttnBackwardAuxTensors(&aux_input_tensors, input_batch, attn_heads, q_max_seqlen, kv_max_seqlen, dtype, softmax_aux, rng_state); - auto workspace_tensor = - TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; @@ -910,7 +905,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SmallSeqAttnForwardHandler, SmallSeqAttnForwardFFI .Arg() .Arg() .Arg() - .Arg() .RemainingArgs() .Ret() .Ret() From f4cc5fadd6588c473744902b946946816a8f72fd Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 24 Apr 2026 07:10:19 +0000 Subject: [PATCH 18/21] Fixed 
From f4cc5fadd6588c473744902b946946816a8f72fd Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Fri, 24 Apr 2026 07:10:19 +0000
Subject: [PATCH 18/21] Fixed small-seq pytests

---
 tests/jax/test_fused_attn.py                 |  3 +-
 .../common/fused_attn_rocm/fused_attn.cpp    | 43 +++++++++++++++++--
 .../fused_attn_rocm/fused_attn_small_seq.cpp |  2 +-
 .../fused_attn_rocm/fused_attn_small_seq.h   |  6 +--
 transformer_engine/jax/attention.py          |  2 +-
 .../jax/csrc/extensions/attention.cpp        |  6 ++-
 6 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index d0b8e29b9..bf50a3bd4 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -526,7 +526,8 @@ def _setup_thd_segments_small_seq(self, generate_random_segment_ids):
         )
 
     def _setup_inputs(self):
-        self._check_configs()
+        if not self.use_small_seq_thd_setup:
+            self._check_configs()
 
         # Create a mesh for distributed tests
         self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
index 2be2e0668..2860eddbe 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -4,6 +4,7 @@
  * License for AMD contributions = MIT. See LICENSE for more information
  ************************************************************************/
 
+#include 
 #include 
 #include 
 #include 
@@ -894,6 +895,34 @@ size_t nvte_fused_attn_small_seq_bwd_workspace_size(size_t batch, size_t attn_he
       batch, attn_heads, max_seqlen_kv, static_cast<DType>(dtype));
 }
 
+namespace {
+
+// Validate runtime max s_q == 1 and max s_kv in [2, 16]; returns runtime max KV length.
+size_t nvte_assert_small_seq_runtime_max_seqlen(uint64_t b, const void *dev_ptr_cu_seqlens_q,
+                                                const void *dev_ptr_cu_seqlens_kv, void *workspace,
+                                                size_t workspace_bytes, const char *log_tag,
+                                                cudaStream_t stream) {
+  constexpr size_t runtime_seqlen_bytes = sizeof(uint64_t);
+  NVTE_CHECK(workspace_bytes >= runtime_seqlen_bytes, log_tag,
+             "workspace too small to compute runtime max seqlen (need at least ",
+             runtime_seqlen_bytes, " bytes).");
+  const size_t runtime_s_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+      b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream));
+  const size_t runtime_s_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
+      b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream));
+  if (const char *env_ck = std::getenv("NVTE_LOG_CK_CONFIG");
+      env_ck != nullptr && std::string(env_ck) == "1") {
+    std::cout << std::endl << log_tag << "b=" << b << ", runtime_max_seqlen_q=" << runtime_s_q
+              << ", runtime_max_seqlen_kv=" << runtime_s_kv << std::endl;
+  }
+  NVTE_CHECK(runtime_s_q == 1 && runtime_s_kv >= 2 && runtime_s_kv <= 16, log_tag,
+             "small-seq requires runtime s_q==1 and s_kv in [2,16]; got runtime_s_q=", runtime_s_q,
+             ", runtime_s_kv=", runtime_s_kv, ".");
+  return runtime_s_kv;
+}
+
+} // namespace
+
 void nvte_fused_attn_small_seq_fwd(
     const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, NVTETensor S,
     NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
@@ -946,7 +975,7 @@ void nvte_fused_attn_small_seq_fwd(
   void *attn_weights_buf = softmax_aux_tensor->data.dptr;
 
   if (wkspace->data.dptr == nullptr) {
-    wkspace->data.shape = {1};
+    wkspace->data.shape = {sizeof(uint64_t)};
     wkspace->data.dtype = DType::kByte;
     return;
   }
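The {1} -> {sizeof(uint64_t)} change in the hunk above is the forward half of a two-phase
workspace protocol: the shape-query pass (workspace pointer still null) advertises how many
bytes the real pass will need, and the real pass then reuses the first 8 bytes as scratch for
the runtime-seqlen reduction. A minimal Python sketch of that protocol, with hypothetical names
and none of the actual TE signatures:

    PROBE_BYTES = 8  # sizeof(uint64_t): scratch for one runtime-seqlen value

    def small_seq_fwd(workspace=None):
        """NVTE-style two-phase call (illustrative only)."""
        if workspace is None:  # shape-query pass: report the required workspace size
            return PROBE_BYTES
        assert len(workspace) >= PROBE_BYTES, "workspace too small"
        # real pass: the runtime-seqlen reduction writes into workspace[:PROBE_BYTES]
        workspace[:PROBE_BYTES] = (0).to_bytes(PROBE_BYTES, "little")

    size = small_seq_fwd()              # phase 1: query
    small_seq_fwd(bytearray(size))      # phase 2: run with the allocated buffer

Advertising only {1} in the query pass is exactly the bug this patch fixes: the real pass would
then fail the "workspace too small" check before the reduction could run.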
@@ -961,8 +990,12 @@ void nvte_fused_attn_small_seq_fwd(
   }
   workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype);
 
+  const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen(
+      static_cast<uint64_t>(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
+      wkspace->data.dptr, workspace_bytes, "attn_fwd(small-seq kernel): ", stream);
+
   fused_attn_rocm::fused_attn_small_seq_fwd(
-      b, h_q, h_kv, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout,
+      b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout,
       input_Q->data.dptr, input_K->data.dptr, input_V->data.dptr, output_O->data.dptr,
       attn_weights_buf, input_cu_seqlens_kv->data.dptr, input_cu_seqlens_kv_padded->data.dptr,
       dev_ptr_seed, dev_ptr_offset, input_Q->data.dtype, wkspace->data.dptr, &workspace_bytes,
@@ -1043,8 +1076,12 @@ void nvte_fused_attn_small_seq_bwd(
   workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype);
   NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small.");
 
+  const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen(
+      static_cast<uint64_t>(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
+      wkspace->data.dptr, workspace_bytes, "attn_bwd(small-seq kernel): ", stream);
+
   fused_attn_rocm::fused_attn_small_seq_bwd(
-      b, h_q, h_kv, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, input_Q->data.dptr,
+      b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, attn_scale, dropout, input_Q->data.dptr,
       input_K->data.dptr, input_V->data.dptr, input_O->data.dptr, input_dO->data.dptr,
       attn_weights_tensor->data.dptr, output_dQ->data.dptr, output_dK->data.dptr,
       output_dV->data.dptr, input_cu_seqlens_kv->data.dptr,
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp
index 223ce8981..7996e910e 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp
@@ -1070,7 +1070,7 @@ bool is_small_seq_attn_supported(
   // max_seqlen_q / max_seqlen_kv are compile-time or padded upper bounds; actual varlen
   // lengths are in cu_seqlens / offsets. Do not reject here based on those maxima — callers
   // must ensure runtime lengths and the launch-time max_seqlen_kv passed to fwd/bwd match
-  // the small-seq kernel contract (e.g. dispatch cases 2..16, effective q len 1).
+  // the small-seq kernel contract.
   (void)max_seqlen_q;
   (void)max_seqlen_kv;
 
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h
index f3abd6bbd..c9e2cf134 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h
@@ -41,8 +41,7 @@ size_t fused_attn_small_seq_bwd_workspace_size(size_t b,
 
 /** Forward: Q,K,V -> O; attention weights written to attn_weights_buffer (same as output_S).
  * attn_weights_buffer is also used as internal workspace (scores then overwritten by attn
- * weights). No separate workspace required for the launcher; caller may use workspace for
- * get_runtime_max_seqlen (8 bytes). */
+ * weights). */
 void fused_attn_small_seq_fwd(size_t b,
                               size_t h_q,
                               size_t h_kv,
@@ -67,7 +66,8 @@ void fused_attn_small_seq_fwd(size_t b,
                               cudaStream_t stream);
 
 /** Backward: dO, O, attn_weights -> dQ, dK, dV. attn_weights is the buffer from forward
- * (output_S). workspace must be at least fused_attn_small_seq_bwd_workspace_size. */
+ * (output_S). workspace must be at least fused_attn_small_seq_bwd_workspace_size.
+ * max_seqlen_kv is the runtime max KV length when invoked from nvte_fused_attn_small_seq_bwd. */
 void fused_attn_small_seq_bwd(size_t b,
                               size_t h_q,
                               size_t h_kv,
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 7d3beb182..f957273b3 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -1003,7 +1003,7 @@ def _fused_attn_bwd_rule(
 
 @partial(
     jax.custom_vjp,
-    nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11),
+    nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12),
 )
 def _fused_attn_small_seq(
     qkv: Tuple[jnp.ndarray, ...],
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index 605513362..2fb881e7b 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -388,7 +388,11 @@ pybind11::tuple GetSmallSeqAttnForwardWorkspaceSizes(
              "GetSmallSeqAttnForwardWorkspaceSizes: configuration not supported.");
   NVTE_CHECK(bias_batch == 0 && bias_heads == 0,
              "GetSmallSeqAttnForwardWorkspaceSizes: bias not supported for small-seq.");
-  TensorWrapper query_workspace_tensor(nullptr, std::vector<size_t>{1}, DType::kByte);
+  // At least 8 bytes: nvte_fused_attn_small_seq_fwd uses the start of workspace for
+  // ck_fused_attn::get_runtime_max_seqlen
+  constexpr size_t k_small_seq_runtime_probe_bytes = 8;
+  TensorWrapper query_workspace_tensor(nullptr,
+                                       std::vector<size_t>{k_small_seq_runtime_probe_bytes},
+                                       DType::kByte);
   return pybind11::make_tuple(MakeShapeVector(query_workspace_tensor.shape()),
                               query_workspace_tensor.dtype());
 }
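The nondiff_argnums bump in the attention.py hunk above follows directly from jax.custom_vjp
mechanics: every index listed there is treated as static configuration, excluded from
differentiation, and passed to the bwd rule ahead of the residuals, so adding one more config
argument to _fused_attn_small_seq means extending the tuple. A self-contained toy example of
those mechanics (not TE's actual rules):

    import jax
    from functools import partial

    @partial(jax.custom_vjp, nondiff_argnums=(1,))  # `scale` is static config
    def scaled_square(x, scale):
        return scale * x * x

    def scaled_square_fwd(x, scale):
        return scaled_square(x, scale), x      # residuals: only differentiable data

    def scaled_square_bwd(scale, x, g):        # static args arrive first, then residuals
        return (g * scale * 2.0 * x,)          # one cotangent per differentiable input

    scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

    print(jax.grad(scaled_square)(3.0, 2.0))   # 2 * 2.0 * 3.0 = 12.0

Forgetting the extra index would instead thread the new argument through as a differentiable
input, and the bwd rule would be required to return a cotangent for it.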
From 493b7b4d54dd1577a27501c2f8972307bea11762 Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Fri, 24 Apr 2026 07:34:21 +0000
Subject: [PATCH 19/21] Added seq-packing pytests for small-seq kernels

---
 tests/jax/test_fused_attn.py                       | 4 ++--
 transformer_engine/jax/cpp_extensions/attention.py | 9 ++++++---
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index bf50a3bd4..37c0fd77c 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -1494,8 +1494,8 @@ def xla_gpu_graph_disabled():
         pytest.param(4000, 1, 16, 16, 16, 256, 256, id="4000-1-16-16-16-256-256"),
         pytest.param(4000, 1, 2, 16, 16, 512, 512, id="4000-1-2-16-16-512-512"),
         pytest.param(4000, 1, 4, 16, 16, 512, 512, id="4000-1-4-16-16-512-512"),
-        # pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
-        # pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
+        pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
+        pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
     ],
 )
 @pytest.mark.skipif(
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index 18688f56c..aed1e3ab8 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -1528,7 +1528,7 @@ def abstract(
         assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
 
         (
-            batch_shape,
+            _,  # logical batch_shape; the small-seq NVTE batch is cu_seqlens_q.shape[0] - 1
             q_max_seqlen,
             kv_max_seqlen,
             attn_heads,

         bias_batch = bias_heads = 0
 
         deterministic = not FusedAttnHelper.is_non_deterministic_allowed()
-        input_batch = reduce(operator.mul, batch_shape)
+        # NVTE uses b = cu_seqlens_q.shape[0] - 1 (one packed segment per slot), not
+        # reduce(batch_shape): e.g. seq-packing with max_seqlen_q > 1 yields a cu_seqlens
+        # of length batch*segments + 1 while Q keeps only the logical batch as leading dim.
+        small_seq_workspace_batch = q_seqlen_or_cu_seqlen_aval.shape[0] - 1
         wkspace_shape, wkspace_dtype = transformer_engine_jax.get_small_seq_attn_bwd_workspace_sizes(
-            input_batch,
+            small_seq_workspace_batch,
             bias_batch,
             q_max_seqlen,
             kv_max_seqlen,
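The batch bookkeeping fixed in the hunk above is easy to trip over, because the logical batch
and the NVTE varlen batch genuinely diverge under sequence packing. A sketch with hypothetical
numbers, loosely modeled on the seqpack-2048-2-4 configuration (the segment count here is an
assumption for illustration):

    import numpy as np

    batch, segments_per_seq = 2048, 2
    batch_shape = (batch,)                     # leading dims of the THD Q tensor

    # With seq-packing, cu_seqlens_q has one entry per packed segment plus the
    # leading 0, so its length is batch * segments_per_seq + 1, independent of
    # Q's leading dimensions.
    cu_seqlens_q = np.zeros(batch * segments_per_seq + 1, dtype=np.int32)

    logical_batch = int(np.prod(batch_shape))  # reduce(operator.mul, batch_shape)
    nvte_batch = cu_seqlens_q.shape[0] - 1     # what the varlen kernels index by

    assert logical_batch == 2048 and nvte_batch == 4096  # diverge once segments > 1

Sizing the workspace with the logical batch would under-allocate by a factor of the segment
count, which is why the abstract rule now derives the batch from the cu_seqlens aval.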
+  const bool full_window =
+      (window_size_left == -1 && window_size_right == -1) ||
+      (window_size_left == -1 && window_size_right == 0);
+  if (!full_window) {
+    if (log) {
+      std::cout << "small-seq: sliding/local window attention not supported" << std::endl;
     }
+    return false;
   }
 
   return true;

From b183024fab07d0db49254fdfa3ba554c6c8969be Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Mon, 27 Apr 2026 17:39:36 +0000
Subject: [PATCH 21/21] Fixed jax/test_fused_attn.py

---
 tests/jax/test_fused_attn.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 42135a621..83d2c0006 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -759,10 +759,10 @@ def generate_random_segment_ids(
         else:
             # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
             min_segment_len = None
-        if (
-            self.window_size is not None or self.attn_mask_type.is_bottom_right()
-        ):  # SWA or BRCM requires kv_len >= q_len
-            min_segment_len = self.seqlens_q
+            if (
+                self.window_size is not None or self.attn_mask_type.is_bottom_right()
+            ):  # SWA or BRCM requires kv_len >= q_len
+                min_segment_len = self.seqlens_q
         self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
             self.batch_size,
             self.max_seqlen_kv,