-
Notifications
You must be signed in to change notification settings - Fork 28
[TE] Phase 2 of small-seq cross-attn integration: a separate cpp backend and a new jax api #542
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
10f7ee6
b3ef62c
db685c4
b6a5ee8
75f7cfa
c737072
4537cce
c6e0eae
366945e
d5afb6f
006edee
f8a5ce8
9dd9b7e
1c6ffd5
0198a0e
317f152
1104ee2
f4cc5fa
493b7b4
09ab963
98bde78
b183024
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,12 +4,14 @@ | |
| * License for AMD contributions = MIT. See LICENSE for more information | ||
| ************************************************************************/ | ||
|
|
||
| #include <ck_fused_attn/ck_fused_attn.hpp> | ||
| #include <iostream> | ||
| #include <string> | ||
| #include <tuple> | ||
| #include "transformer_engine/fused_attn.h" | ||
| #include "fused_attn_aotriton.h" | ||
| #include "fused_attn_ck.h" | ||
| #include "fused_attn_small_seq.h" | ||
| #include "../common.h" | ||
| #include "utils.h" | ||
|
|
||
|
|
@@ -894,6 +896,218 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso | |
| } | ||
| } | ||
|
|
||
| bool nvte_is_small_seq_attn_supported( | ||
| NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, | ||
| NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, | ||
| size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, | ||
| size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, | ||
| int64_t window_size_right) { | ||
| return transformer_engine::fused_attn_rocm::is_small_seq_attn_supported( | ||
| q_dtype, kv_dtype, qkv_layout, bias_type, attn_mask_type, dropout, num_attn_heads, | ||
| num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, | ||
| window_size_right); | ||
| } | ||
|
|
||
| size_t nvte_fused_attn_small_seq_bwd_workspace_size(size_t batch, size_t attn_heads, | ||
| size_t max_seqlen_kv, NVTEDType dtype) { | ||
| return transformer_engine::fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size( | ||
| batch, attn_heads, max_seqlen_kv, static_cast<transformer_engine::DType>(dtype)); | ||
| } | ||
|
|
||
| namespace { | ||
|
|
||
| // Validate runtime max s_q == 1 and max s_kv in [2, 16]; returns runtime max KV length. | ||
| size_t nvte_assert_small_seq_runtime_max_seqlen(uint64_t b, const void *dev_ptr_cu_seqlens_q, | ||
| const void *dev_ptr_cu_seqlens_kv, void *workspace, | ||
| size_t workspace_bytes, const char *log_tag, | ||
| cudaStream_t stream) { | ||
| constexpr size_t runtime_seqlen_bytes = sizeof(uint64_t); | ||
| NVTE_CHECK(workspace_bytes >= runtime_seqlen_bytes, log_tag, | ||
| "workspace too small to compute runtime max seqlen (need at least ", runtime_seqlen_bytes, | ||
| " bytes)."); | ||
| const size_t runtime_s_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen( | ||
| b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream)); | ||
| const size_t runtime_s_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen( | ||
| b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream)); | ||
| if (const char *env_ck = std::getenv("NVTE_LOG_CK_CONFIG"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At this API level we use |
||
| env_ck != nullptr && std::string(env_ck) == "1") { | ||
| std::cout << std::endl << log_tag << "b=" << b << ", runtime_max_seqlen_q=" << runtime_s_q | ||
| << ", runtime_max_seqlen_kv=" << runtime_s_kv << std::endl; | ||
| } | ||
| NVTE_CHECK(runtime_s_q == 1 && runtime_s_kv >= 2 && runtime_s_kv <= 16, log_tag, | ||
| "small-seq requires runtime s_q==1 and s_kv in [2,16]; got runtime_s_q=", runtime_s_q, | ||
| ", runtime_s_kv=", runtime_s_kv, "."); | ||
| return runtime_s_kv; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| void nvte_fused_attn_small_seq_fwd( | ||
| const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, | ||
| NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, | ||
| const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, | ||
| const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, | ||
| const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, | ||
| size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, | ||
| NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, | ||
| int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_fused_attn_small_seq_fwd); | ||
| using namespace transformer_engine; | ||
| const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); | ||
| const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); | ||
| const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); | ||
| const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); | ||
| const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we even support dropout with this rng_state? |
||
| const Tensor *input_Q = convertNVTETensorCheck(Q); | ||
| const Tensor *input_K = convertNVTETensorCheck(K); | ||
| const Tensor *input_V = convertNVTETensorCheck(V); | ||
| convertNVTETensorCheck(Bias); | ||
| convertNVTETensorCheck(S); | ||
| Tensor *output_O = convertNVTETensorCheck(O); | ||
| Tensor *wkspace = convertNVTETensorCheck(workspace); | ||
|
|
||
| auto ndim = input_Q->data.shape.size(); | ||
| size_t b = input_cu_seqlens_q->data.shape[0] - 1; | ||
| size_t h_q = input_Q->data.shape[ndim - 2]; | ||
| size_t h_kv = input_K->data.shape[ndim - 2]; | ||
| size_t d_qk = input_Q->data.shape[ndim - 1]; | ||
| size_t d_v = input_V->data.shape[ndim - 1]; | ||
|
|
||
| const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); | ||
| const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); | ||
|
|
||
| log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, | ||
| dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, | ||
| window_size_left, window_size_right); | ||
|
|
||
| std::tie(window_size_left, window_size_right) = | ||
| check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); | ||
|
|
||
| NVTE_CHECK( | ||
| fused_attn_rocm::is_small_seq_attn_supported( | ||
| Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, | ||
| max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right), | ||
| "nvte_fused_attn_small_seq_fwd: configuration not supported for small-seq path."); | ||
|
|
||
| Tensor *softmax_aux_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); | ||
| void *attn_weights_buf = softmax_aux_tensor->data.dptr; | ||
|
|
||
| if (wkspace->data.dptr == nullptr) { | ||
| wkspace->data.shape = {sizeof(uint64_t)}; | ||
| wkspace->data.dtype = DType::kByte; | ||
| return; | ||
| } | ||
|
wangye805 marked this conversation as resolved.
|
||
|
|
||
| void *dev_ptr_seed = input_rng_state->data.dptr; | ||
| void *dev_ptr_offset = | ||
| reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(input_rng_state->data.dptr) + 1); | ||
|
|
||
| size_t workspace_bytes = 1; | ||
| for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { | ||
| workspace_bytes *= wkspace->data.shape[i]; | ||
| } | ||
| workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype); | ||
|
|
||
| const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen( | ||
| static_cast<uint64_t>(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, | ||
| wkspace->data.dptr, workspace_bytes, "attn_fwd(small-seq kernel): ", stream); | ||
|
|
||
| fused_attn_rocm::fused_attn_small_seq_fwd( | ||
| b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, | ||
| input_Q->data.dptr, input_K->data.dptr, input_V->data.dptr, output_O->data.dptr, | ||
| attn_weights_buf, input_cu_seqlens_kv->data.dptr, input_cu_seqlens_kv_padded->data.dptr, | ||
| dev_ptr_seed, dev_ptr_offset, input_Q->data.dtype, wkspace->data.dptr, &workspace_bytes, | ||
| stream); | ||
| (void)page_table_k; | ||
| (void)page_table_v; | ||
| } | ||
|
|
||
| void nvte_fused_attn_small_seq_bwd( | ||
| const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, | ||
| const NVTETensor dO, const NVTETensor S, NVTETensor dP, | ||
| const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, | ||
| NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, | ||
| const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, | ||
| size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, | ||
| NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, | ||
| int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_fused_attn_small_seq_bwd); | ||
| (void)deterministic; | ||
| using namespace transformer_engine; | ||
| const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); | ||
| const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); | ||
| const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); | ||
| const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); | ||
| const Tensor *input_Q = convertNVTETensorCheck(Q); | ||
| const Tensor *input_K = convertNVTETensorCheck(K); | ||
| const Tensor *input_V = convertNVTETensorCheck(V); | ||
| const Tensor *input_O = convertNVTETensorCheck(O); | ||
| const Tensor *input_dO = convertNVTETensorCheck(dO); | ||
| convertNVTETensorCheck(S); | ||
| convertNVTETensorCheck(dP); | ||
| Tensor *output_dQ = convertNVTETensorCheck(dQ); | ||
| Tensor *output_dK = convertNVTETensorCheck(dK); | ||
| Tensor *output_dV = convertNVTETensorCheck(dV); | ||
| convertNVTETensorCheck(dBias); | ||
| Tensor *wkspace = convertNVTETensorCheck(workspace); | ||
|
|
||
| const Tensor *attn_weights_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); | ||
| convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); | ||
|
|
||
| auto ndim = input_Q->data.shape.size(); | ||
| size_t b = input_cu_seqlens_q->data.shape[0] - 1; | ||
| size_t h_q = input_Q->data.shape[ndim - 2]; | ||
| size_t h_kv = input_K->data.shape[ndim - 2]; | ||
| size_t d_qk = input_Q->data.shape[ndim - 1]; | ||
| size_t d_v = input_V->data.shape[ndim - 1]; | ||
|
|
||
| const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); | ||
| const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); | ||
|
|
||
| log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, | ||
| dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, | ||
| window_size_left, window_size_right); | ||
|
|
||
| std::tie(window_size_left, window_size_right) = | ||
| check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); | ||
|
|
||
| NVTE_CHECK( | ||
| fused_attn_rocm::is_small_seq_attn_supported( | ||
| Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, | ||
| max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right), | ||
| "nvte_fused_attn_small_seq_bwd: configuration not supported for small-seq path."); | ||
|
|
||
| size_t req_bytes = fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size( | ||
| b, h_q, max_seqlen_kv, input_Q->data.dtype); | ||
|
|
||
| if (wkspace->data.dptr == nullptr) { | ||
| wkspace->data.shape = {req_bytes}; | ||
| wkspace->data.dtype = DType::kByte; | ||
| return; | ||
| } | ||
|
|
||
| size_t workspace_bytes = 1; | ||
| for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { | ||
| workspace_bytes *= wkspace->data.shape[i]; | ||
| } | ||
| workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype); | ||
| NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small."); | ||
|
Comment on lines
+1095
to
+1096
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: now we don't have mixed old ck + new small_seq flow, will we still get more workspace than we need? |
||
|
|
||
| const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen( | ||
| static_cast<uint64_t>(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, | ||
| wkspace->data.dptr, workspace_bytes, "attn_bwd(small-seq kernel): ", stream); | ||
|
|
||
| fused_attn_rocm::fused_attn_small_seq_bwd( | ||
| b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, attn_scale, dropout, input_Q->data.dptr, | ||
| input_K->data.dptr, input_V->data.dptr, input_O->data.dptr, input_dO->data.dptr, | ||
| attn_weights_tensor->data.dptr, output_dQ->data.dptr, output_dK->data.dptr, | ||
| output_dV->data.dptr, input_cu_seqlens_kv->data.dptr, | ||
| input_cu_seqlens_kv_padded->data.dptr, input_Q->data.dtype, wkspace->data.dptr, | ||
| &workspace_bytes, stream); | ||
| } | ||
|
|
||
| uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t max_batch_size, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_get_runtime_num_segments); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This calls `ck_fused_attn::get_runtime_max_seqlen` unconditionally, which breaks AOTriton-only builds.