Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
10f7ee6
Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)
VeeraRajasekhar Feb 24, 2026
b3ef62c
Addressed comments
VeeraRajasekhar Feb 25, 2026
db685c4
Addressed reviews
VeeraRajasekhar Feb 27, 2026
b6a5ee8
Guard CK small-seq behind NVTE_FUSED_ATTN_CK_SMALLSEQ=1; add FP16 sup…
VeeraRajasekhar Feb 27, 2026
75f7cfa
ROCm CK unfused small-seq: env guard, FP16, tests, and logging
VeeraRajasekhar Feb 27, 2026
c737072
Disabled xla_gpu_graph_level
VeeraRajasekhar Feb 27, 2026
4537cce
Updated XLA_FLAGS in ci/jax.sh
VeeraRajasekhar Feb 27, 2026
c6e0eae
Addressed comments
VeeraRajasekhar Mar 3, 2026
366945e
Refactored input generation for smallseq flow
VeeraRajasekhar Mar 3, 2026
d5afb6f
[ROCm] apply more strict filtering for just cross-attn and fix the so…
wangye805 Mar 13, 2026
006edee
Refactor small-seq kernels and add NVTE hooks for explicit fused-attn…
VeeraRajasekhar Apr 15, 2026
f8a5ce8
feat(jax): C++ FFI bridge for ROCm small-seq attention
VeeraRajasekhar Apr 17, 2026
9dd9b7e
feat(jax): pybind registration for ROCm small-seq attention FFI
VeeraRajasekhar Apr 17, 2026
1c6ffd5
feat(jax): XLA primitives for ROCm small-seq attention
VeeraRajasekhar Apr 17, 2026
0198a0e
feat(jax): public fused_attn_small_seq API with custom_vjp
VeeraRajasekhar Apr 21, 2026
317f152
commit_message_stage_f_tests_ci.txt
VeeraRajasekhar Apr 21, 2026
1104ee2
Fixed build issues
VeeraRajasekhar Apr 23, 2026
f4cc5fa
Fixed small-seq pytests
VeeraRajasekhar Apr 24, 2026
493b7b4
Added seq-packing pytests for small-seq kernels
VeeraRajasekhar Apr 24, 2026
09ab963
Merge branch 'dev' of https://github.com/ROCm/TransformerEngine into …
VeeraRajasekhar Apr 24, 2026
98bde78
Addressed reviews
VeeraRajasekhar Apr 25, 2026
b183024
Fixed jax/test_fused_attn.py
VeeraRajasekhar Apr 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ run_test_config() {
export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests
run_default_fa 1 test_custom_call_compute.py
run_default_fa 1 test_functions.py
run 1 test_fused_attn.py
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
run 1 test_fused_attn.py -k 'not test_fused_attn_small_seq_explicit_api' # skip smallseq in normal flow
XLA_FLAGS='--xla_gpu_enable_command_buffer=' run 1 test_fused_attn.py -k 'test_fused_attn_small_seq_explicit_api' # explicit small-seq API;
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_fused_attn_small_seq_explicit_api' # Using FAv2 for forward and backward pass
run_default_fa 1 test_layer.py # it effectively always uses unfused attention
run_default_fa 1 test_sanity_import.py
run_default_fa 1 test_softmax.py
Expand Down
392 changes: 368 additions & 24 deletions tests/jax/test_fused_attn.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ else()
list(APPEND transformer_engine_cuda_sources
fused_attn_rocm/fused_attn_aotriton.cpp
fused_attn_rocm/fused_attn_ck.cpp
fused_attn_rocm/fused_attn_small_seq.cpp
fused_attn_rocm/utils.cpp)
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ hipError_t ck_attn_varlen_bwd(
int how_v3_bf16_cvt,
hipStream_t stream);

// Returns the maximum per-segment sequence length observed at runtime, derived
// from the cumulative sequence-length array on device.
//   b                    - number of segments (callers pass the batch size; the
//                          cu_seqlen array then holds b+1 cumulative entries)
//   cu_seqlen_ptr        - device pointer to cumulative (unpadded) seqlens
//   cu_seqlen_padded_ptr - device pointer to cumulative padded seqlens; callers
//                          in this file pass nullptr -- TODO confirm semantics
//   workspace            - device scratch buffer used by the reduction
//   stream               - HIP stream the work is enqueued on
// NOTE(review): returning the result as a plain uint64_t suggests a host-side
// synchronization on `stream` inside this call -- confirm against the
// implementation before using it on latency-critical paths.
uint64_t get_runtime_max_seqlen(uint64_t b,
const void* cu_seqlen_ptr,
const void* cu_seqlen_padded_ptr,
void* workspace,
hipStream_t stream);

}//namespace ck_fused_attn
#endif // CK_FUSED_ATTN_H

214 changes: 214 additions & 0 deletions transformer_engine/common/fused_attn_rocm/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/

#include <ck_fused_attn/ck_fused_attn.hpp>
#include <iostream>
#include <string>
#include <tuple>
#include "transformer_engine/fused_attn.h"
#include "fused_attn_aotriton.h"
#include "fused_attn_ck.h"
#include "fused_attn_small_seq.h"
#include "../common.h"
#include "utils.h"

Expand Down Expand Up @@ -894,6 +896,218 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
}
}

// C API entry point: reports whether the ROCm small-seq fused-attention path
// supports the given configuration. Thin shim over the backend predicate; all
// arguments are forwarded unchanged.
bool nvte_is_small_seq_attn_supported(
    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout,
    size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv,
    size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
    int64_t window_size_right) {
  namespace rocm = transformer_engine::fused_attn_rocm;
  const bool supported = rocm::is_small_seq_attn_supported(
      q_dtype, kv_dtype, qkv_layout, bias_type, attn_mask_type, dropout, num_attn_heads,
      num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left,
      window_size_right);
  return supported;
}

size_t nvte_fused_attn_small_seq_bwd_workspace_size(size_t batch, size_t attn_heads,
size_t max_seqlen_kv, NVTEDType dtype) {
return transformer_engine::fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size(
batch, attn_heads, max_seqlen_kv, static_cast<transformer_engine::DType>(dtype));
}

namespace {

// Validates the small-seq runtime contract (max s_q == 1 and max s_kv in
// [2, 16]) from the device-resident cumulative sequence lengths, and returns
// the runtime max KV length.
//
//   b                      - batch size (number of segments)
//   dev_ptr_cu_seqlens_q   - device pointer to cumulative Q seqlens
//   dev_ptr_cu_seqlens_kv  - device pointer to cumulative KV seqlens
//   workspace              - device scratch buffer, >= sizeof(uint64_t) bytes
//   workspace_bytes        - size of `workspace` in bytes
//   log_tag                - prefix for log and error messages
//   stream                 - stream used for the device-side reductions
//
// Aborts via NVTE_CHECK when the workspace is too small or when the runtime
// lengths fall outside the small-seq envelope.
size_t nvte_assert_small_seq_runtime_max_seqlen(uint64_t b, const void *dev_ptr_cu_seqlens_q,
                                                const void *dev_ptr_cu_seqlens_kv, void *workspace,
                                                size_t workspace_bytes, const char *log_tag,
                                                cudaStream_t stream) {
  constexpr size_t runtime_seqlen_bytes = sizeof(uint64_t);
  NVTE_CHECK(workspace_bytes >= runtime_seqlen_bytes, log_tag,
             "workspace too small to compute runtime max seqlen (need at least ",
             runtime_seqlen_bytes, " bytes).");
  // TODO(review): ck_fused_attn::get_runtime_max_seqlen is called
  // unconditionally here, which breaks AOTriton-only builds where the CK
  // backend is not compiled in -- guard behind the CK build flag.
  const size_t runtime_s_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
      b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream));
  const size_t runtime_s_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
      b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream));
  // This is an API-level helper, so logging follows the API-level switch
  // (NVTE_LOG_FUSED_ATTN_CONFIG), not the CK-backend-specific NVTE_LOG_CK_CONFIG.
  if (const char *env_log = std::getenv("NVTE_LOG_FUSED_ATTN_CONFIG");
      env_log != nullptr && std::string(env_log) == "1") {
    std::cout << std::endl << log_tag << "b=" << b << ", runtime_max_seqlen_q=" << runtime_s_q
              << ", runtime_max_seqlen_kv=" << runtime_s_kv << std::endl;
  }
  NVTE_CHECK(runtime_s_q == 1 && runtime_s_kv >= 2 && runtime_s_kv <= 16, log_tag,
             "small-seq requires runtime s_q==1 and s_kv in [2,16]; got runtime_s_q=", runtime_s_q,
             ", runtime_s_kv=", runtime_s_kv, ".");
  return runtime_s_kv;
}

}  // namespace

// Forward pass for the ROCm small-seq fused-attention path (runtime s_q == 1,
// s_kv in [2, 16]). When `workspace` has a null data pointer this is a
// size-query call: it fills in the required workspace shape/dtype and returns.
void nvte_fused_attn_small_seq_fwd(
    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
    NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
    const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
    const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
    const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
    size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
    cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_attn_small_seq_fwd);
  using namespace transformer_engine;

  // Tensors accepted for API symmetry but not consumed by the small-seq kernel
  // (Bias, S, q_padded, page tables) are still validated so malformed handles
  // fail early.
  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
  convertNVTETensorCheck(cu_seqlens_q_padded);
  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
  // NOTE(review): rng_state is split into seed/offset and forwarded below, but
  // whether the small-seq path actually supports dropout is unconfirmed.
  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
  const Tensor *input_Q = convertNVTETensorCheck(Q);
  const Tensor *input_K = convertNVTETensorCheck(K);
  const Tensor *input_V = convertNVTETensorCheck(V);
  convertNVTETensorCheck(Bias);
  convertNVTETensorCheck(S);
  Tensor *output_O = convertNVTETensorCheck(O);
  Tensor *wkspace = convertNVTETensorCheck(workspace);
  (void)page_table_k;  // paged KV caches are not used by the small-seq kernel
  (void)page_table_v;

  // Problem sizes: Q/K/V end in [..., heads, head_dim]; cu_seqlens_q holds
  // batch+1 cumulative entries.
  auto ndim = input_Q->data.shape.size();
  size_t b = input_cu_seqlens_q->data.shape[0] - 1;
  size_t h_q = input_Q->data.shape[ndim - 2];
  size_t h_kv = input_K->data.shape[ndim - 2];
  size_t d_qk = input_Q->data.shape[ndim - 1];
  size_t d_v = input_V->data.shape[ndim - 1];

  const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
  const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

  log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
                        dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v,
                        window_size_left, window_size_right);

  std::tie(window_size_left, window_size_right) =
      check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));

  NVTE_CHECK(
      fused_attn_rocm::is_small_seq_attn_supported(
          Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
          max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right),
      "nvte_fused_attn_small_seq_fwd: configuration not supported for small-seq path.");

  // Aux tensor 0 carries the attention weights consumed by the backward pass.
  Tensor *softmax_aux_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
  void *attn_weights_buf = softmax_aux_tensor->data.dptr;

  // Workspace-size query: report requirements and return without launching.
  if (wkspace->data.dptr == nullptr) {
    wkspace->data.shape = {sizeof(uint64_t)};
    wkspace->data.dtype = DType::kByte;
    return;
  }

  // rng_state layout is [seed, offset] as two consecutive uint64_t values.
  void *dev_ptr_seed = input_rng_state->data.dptr;
  void *dev_ptr_offset =
      reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(input_rng_state->data.dptr) + 1);

  size_t workspace_bytes = 1;
  for (size_t i = 0; i < wkspace->data.shape.size(); ++i) {
    workspace_bytes *= wkspace->data.shape[i];
  }
  workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype);

  // Enforce the small-seq contract and launch with the actual runtime KV
  // length rather than the compile-time maximum.
  const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen(
      static_cast<uint64_t>(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
      wkspace->data.dptr, workspace_bytes, "attn_fwd(small-seq kernel): ", stream);

  fused_attn_rocm::fused_attn_small_seq_fwd(
      b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout,
      input_Q->data.dptr, input_K->data.dptr, input_V->data.dptr, output_O->data.dptr,
      attn_weights_buf, input_cu_seqlens_kv->data.dptr, input_cu_seqlens_kv_padded->data.dptr,
      dev_ptr_seed, dev_ptr_offset, input_Q->data.dtype, wkspace->data.dptr, &workspace_bytes,
      stream);
}

void nvte_fused_attn_small_seq_bwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O,
const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV,
NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_attn_small_seq_bwd);
(void)deterministic;
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
const Tensor *input_Q = convertNVTETensorCheck(Q);
const Tensor *input_K = convertNVTETensorCheck(K);
const Tensor *input_V = convertNVTETensorCheck(V);
const Tensor *input_O = convertNVTETensorCheck(O);
const Tensor *input_dO = convertNVTETensorCheck(dO);
convertNVTETensorCheck(S);
convertNVTETensorCheck(dP);
Tensor *output_dQ = convertNVTETensorCheck(dQ);
Tensor *output_dK = convertNVTETensorCheck(dK);
Tensor *output_dV = convertNVTETensorCheck(dV);
convertNVTETensorCheck(dBias);
Tensor *wkspace = convertNVTETensorCheck(workspace);

const Tensor *attn_weights_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);

auto ndim = input_Q->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];

const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v,
window_size_left, window_size_right);

std::tie(window_size_left, window_size_right) =
check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));

NVTE_CHECK(
fused_attn_rocm::is_small_seq_attn_supported(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right),
"nvte_fused_attn_small_seq_bwd: configuration not supported for small-seq path.");

size_t req_bytes = fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size(
b, h_q, max_seqlen_kv, input_Q->data.dtype);

if (wkspace->data.dptr == nullptr) {
wkspace->data.shape = {req_bytes};
wkspace->data.dtype = DType::kByte;
return;
}

size_t workspace_bytes = 1;
for (size_t i = 0; i < wkspace->data.shape.size(); ++i) {
workspace_bytes *= wkspace->data.shape[i];
}
workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype);
NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small.");
Comment on lines +1095 to +1096
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: now we don't have mixed old ck + new small_seq flow, will we still get more workspace than we need?


const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen(
static_cast<uint64_t>(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
wkspace->data.dptr, workspace_bytes, "attn_bwd(small-seq kernel): ", stream);

fused_attn_rocm::fused_attn_small_seq_bwd(
b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, attn_scale, dropout, input_Q->data.dptr,
input_K->data.dptr, input_V->data.dptr, input_O->data.dptr, input_dO->data.dptr,
attn_weights_tensor->data.dptr, output_dQ->data.dptr, output_dK->data.dptr,
output_dV->data.dptr, input_cu_seqlens_kv->data.dptr,
input_cu_seqlens_kv_padded->data.dptr, input_Q->data.dtype, wkspace->data.dptr,
&workspace_bytes, stream);
}

uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t max_batch_size,
cudaStream_t stream) {
NVTE_API_CALL(nvte_get_runtime_num_segments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,7 @@ void fused_attn_ck_fwd(
size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_q/d_qk;
size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_kv/d_qk;

bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Expand Down Expand Up @@ -1851,7 +1851,6 @@ void fused_attn_ck_fwd(
bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

fused_attn_ck_fwd_impl(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h,
max_tokens_q, max_tokens_kv,
Expand Down
Loading
Loading