diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 9624ba0c334..a97668298fb 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -77,7 +77,7 @@ def time_fwd_bwd(func, *args, **kwargs): dim = 2048 dropout_p = 0.0 -methods = (["Flash2", "Pytorch"] +methods = (["Flash2", "Pytorch", "Flash2Sink"] + (["Triton"] if attention_triton is not None else []) + (["xformers.c"] if xops is not None else []) + (["xformers.f"] if xops is not None else [])) @@ -106,6 +106,17 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "Flash2"] = f time_b[config, "Flash2"] = b + # FlashAttention 2 with sink + if "Flash2Sink" in methods: + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, + device=device, dtype=dtype, requires_grad=True) + sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True) + f, b = time_fwd_bwd( + flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, learnable_sink=sink, repeats=repeats, verbose=False + ) + time_f[config, "Flash2Sink"] = f + time_b[config, "Flash2Sink"] = b + # PyTorch baseline if "Pytorch" in methods: try: diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index c0c0e42176c..9eb4c73a2b9 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -360,7 +360,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { + std::optional gen_, + std::optional learnable_sink_) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -405,7 +406,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > 
num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); @@ -469,6 +470,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult softcap ); + // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( @@ -494,6 +496,18 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + } + if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); @@ -515,7 +529,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. @@ -532,7 +546,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { + std::optional gen_, + std::optional learnable_sink_) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -589,7 +604,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); @@ -734,6 +749,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { 
+ learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + } + if (max_seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream, paged_KV); @@ -783,7 +810,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl const float softcap, const bool deterministic, std::optional gen_, - std::optional &rng_state) { + std::optional &rng_state, + std::optional learnable_sink_, + std::optional dsink_) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); @@ -952,6 +981,30 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + at::Tensor dsink; + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + if (dsink_.has_value() && dsink_.value().dtype() == torch::kFloat32) { + dsink = dsink_.value(); + CHECK_DEVICE(dsink); CHECK_CONTIGUOUS(dsink); + TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension"); + CHECK_SHAPE(dsink, num_heads); + } else { + dsink = torch::empty_like(learnable_sink); + } + 
dsink.zero_(); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + params.dsink_ptr = dsink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + params.dsink_ptr = nullptr; + } + if (seqlen_q > 0) { launch(params, stream); } else { @@ -967,7 +1020,16 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); } - return { dq, dk, dv, softmax_d }; + if (learnable_sink_.has_value()) { + if (!dsink_.has_value()) { + dsink = dsink.to(learnable_sink_.value().dtype()); + } else if (dsink_.value().dtype() != torch::kFloat32) { + dsink = dsink.to(dsink_.value().dtype()); + dsink_.value().copy_(dsink); + } + } + + return { dq, dk, dv, softmax_d, dsink }; } std::vector @@ -994,7 +1056,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const float softcap, const bool deterministic, std::optional gen_, - std::optional &rng_state) { + std::optional &rng_state, + std::optional learnable_sink_, + std::optional dsink_) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); @@ -1181,6 +1245,30 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + at::Tensor dsink; + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + if (dsink_.has_value() && dsink_.value().dtype() == torch::kFloat32) { + dsink = dsink_.value(); + CHECK_DEVICE(dsink); 
CHECK_CONTIGUOUS(dsink); + TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension"); + CHECK_SHAPE(dsink, num_heads); + } else { + dsink = torch::empty_like(learnable_sink); + } + dsink.zero_(); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + params.dsink_ptr = dsink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + params.dsink_ptr = nullptr; + } + if (max_seqlen_q > 0) { launch(params, stream); } else { @@ -1196,7 +1284,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); } - return { dq, dk, dv, softmax_d }; + if (learnable_sink_.has_value()) { + if (!dsink_.has_value()) { + dsink = dsink.to(learnable_sink_.value().dtype()); + } else if (dsink_.value().dtype() != torch::kFloat32) { + dsink = dsink.to(dsink_.value().dtype()); + dsink_.value().copy_(dsink); + } + } + + return { dq, dk, dv, softmax_d, dsink }; } std::vector @@ -1219,7 +1316,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he int window_size_right, const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - int num_splits + int num_splits, + std::optional learnable_sink_ ) { // Otherwise the kernel will be launched from cuda:0 device @@ -1275,7 +1373,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value() && 
!learnable_sink_.has_value(); if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); @@ -1451,6 +1549,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + } + auto stream = at::cuda::getCurrentCUDAStream().stream(); // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, // or paged KV cache diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 8ffbb62d66e..4453c9e55f0 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -140,6 +140,7 @@ struct Flash_fwd_params : public Qkv_params { bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). 
+ void *__restrict__ learnable_sink_ptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -182,6 +183,7 @@ struct Flash_bwd_params : public Flash_fwd_params { bool deterministic; index_t dq_accum_split_stride; + void *__restrict__ dsink_ptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index a9e9fe0ae8e..fbd29ee7bc1 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -77,7 +77,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -100,6 +100,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const BlockInfo binfo(params, bidb); if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + float dsink_val = 0.0f; + + const float sink_val = !Has_sink ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; + int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); if (Is_local) { m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); @@ -451,6 +455,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in clear(acc_dv); clear(acc_dk); + if constexpr (Has_sink) { + float* dsink_block_sum_ptr = reinterpret_cast(smem_ + Kernel_traits::kSmemSize1colblock - sizeof(float)); + if (tidx == 0) *dsink_block_sum_ptr = 0.0f; + } + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); @@ -586,12 +595,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in }; #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { + float dsink_val_cols = 0.0f; #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { + if constexpr (Has_sink) { dsink_val_cols += pointwise_mult(scores(mi, ni), dS(mi, ni), 0.0f); } float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } dS(mi, ni) = scaled_ds; } + if constexpr (Has_sink) { dsink_val += dsink_val_cols * expf(sink_val - lse(mi)); } } // if (cute::thread0()) { print(dS); } @@ -792,6 +804,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); + if constexpr (Has_sink) { + SumOp sum_op; + float* dsink_block_sum_ptr = reinterpret_cast(smem_ + Kernel_traits::kSmemSize1colblock - sizeof(float)); + dsink_val = Allreduce<32>::run(dsink_val, sum_op); + + if (tidx % 32 == 0) { + atomicAdd(dsink_block_sum_ptr, dsink_val); + } + __syncthreads(); + + if (tidx == 0 && params.dsink_ptr != nullptr && *dsink_block_sum_ptr != 0) { + float* dsink_ptr = reinterpret_cast(params.dsink_ptr); + float val = -(*dsink_block_sum_ptr) * params.rp_dropout; + atomicAdd(dsink_ptr + bidh, val); + } + } + } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -823,7 +852,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. 
@@ -833,7 +862,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 72e7a333b3a..925942ac513 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -39,10 +39,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -99,17 +99,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
- // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index d492c87b5c8..6f1fb01e475 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -48,7 +48,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -80,6 +80,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + const float sink_val = !Has_sink ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; + const int n_block_min = !Is_local ? 
0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal || Is_local) { @@ -122,7 +124,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi #pragma unroll for (int m = 0; m < size<1>(tOgO); ++m) { const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = Has_sink ? sink_val : INFINITY; } } return; } @@ -282,7 +284,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); - FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); @@ -340,7 +342,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + ? 
softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2, params.scale_softmax) : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // Convert acc_s from fp32 to fp16/bf16 @@ -430,7 +432,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); @@ -495,7 +497,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -575,6 +577,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons return; } + const float sink_val = !Has_sink ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). @@ -833,7 +836,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons clear(acc_o); - FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); @@ -914,7 +917,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2, params.scale_softmax) : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } @@ -994,7 +997,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1072,7 +1075,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1088,12 +1091,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. 
- FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1102,12 +1105,12 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -1192,6 +1195,25 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); + + if constexpr(Has_sink) { + const int row = tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + const index_t lse_offset = row_offset_lse + col; + int head_idx; + if (params.unpadded_lse) { + // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). + head_idx = lse_offset / (params.b * params.seqlen_q); + } else { + // LSE is written as (b, h, seqlen_q). 
+ head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; + } + const float sink_val_exp = __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + lse_sum += sink_val_exp; + } + } + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 934e7b9114b..b83bc9fce3f 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -29,26 +29,26 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); + 
FLASH_NAMESPACE::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K, bool Has_sink) { static_assert(Log_max_splits >= 1); - FLASH_NAMESPACE::combine_attn_seqk_parallel(params); + FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } template @@ -71,25 +71,27 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { + // Will only return softmax if 
dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); @@ -114,18 +116,21 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
- // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + // We fix the lse for sink in combine_attn_seqk_parallel if Split is true. + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); @@ -140,22 +145,24 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); } } diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index 8c0897488dc..49c824fa987 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -287,7 +287,7 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs ? 
kSmemKVSize + kSmemdSSize + kSmemPSize - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)) + sizeof(float); // + sizeof(float) for dsink static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 01589adedb3..2fd406bba86 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -130,16 +130,24 @@ struct Softmax { using TensorT = decltype(make_tensor(Shape>{})); TensorT row_max, row_sum; + float sink_val; - __forceinline__ __device__ Softmax() {}; + __forceinline__ __device__ Softmax(const float sink_val = -INFINITY) : sink_val(sink_val) {}; - template - __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2, float softmax_scale=1.0f) { // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); if (Is_first) { - FLASH_NAMESPACE::template reduce_max(scores, row_max); + if constexpr (Has_sink) { + const float sink_scaled = sink_val / softmax_scale; + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { row_max(mi) = sink_scaled; } + FLASH_NAMESPACE::template reduce_max(scores, row_max); + } else { + FLASH_NAMESPACE::template reduce_max(scores, row_max); + } FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); FLASH_NAMESPACE::reduce_sum(scores, row_sum); } else { @@ -166,7 +174,7 @@ struct Softmax { } }; - template + template __forceinline__ __device__ TensorT 
normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); @@ -176,6 +184,10 @@ struct Softmax { #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { float sum = row_sum(mi); + if (Has_sink) { + const float max_scaled = row_max(mi) == -INFINITY ? 0.f : row_max(mi) * softmax_scale; + sum += exp2f((sink_val - max_scaled) * float(M_LOG2E)); + } float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index 70d14daf69d..7daa1592b0d 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -109,3 +109,13 @@ return __VA_ARGS__(); \ } \ }() + +#ifdef FLASHATTENTION_DISABLE_SINK + #define SINK_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SINK_SWITCH BOOL_SWITCH +#endif diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index a53b4a3108a..7787a25ea0f 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -85,7 +85,8 @@ def _flash_attn_forward( window_size_right: int, softcap: float, alibi_slopes: Optional[torch.Tensor], - return_softmax: bool + return_softmax: bool, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( @@ -102,6 +103,7 @@ def _flash_attn_forward( softcap, return_softmax, None, + learnable_sink, ) return out, softmax_lse, S_dmask, rng_state @@ -118,7 +120,8 @@ def _flash_attn_forward_fake( window_size_right: int, softcap: float, alibi_slopes: Optional[torch.Tensor], - return_softmax: bool + return_softmax: bool, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] batch_size, seqlen_q, num_heads, head_size = q.shape @@ -136,7 +139,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward else: _wrapped_flash_attn_forward = _flash_attn_forward @@ -163,6 +166,7 @@ def _flash_attn_varlen_forward( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( @@ -187,6 +191,7 @@ def 
_flash_attn_varlen_forward( softcap, return_softmax, None, + learnable_sink, ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -214,6 +219,7 @@ def _flash_attn_varlen_forward_fake( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] paged_kv = block_table is not None @@ -232,7 +238,7 @@ def _flash_attn_varlen_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward else: _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward @@ -258,6 +264,8 @@ def _flash_attn_backward( alibi_slopes: Optional[torch.Tensor], deterministic: bool, rng_state: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] @@ -266,6 +274,7 @@ def _flash_attn_backward( dk, dv, softmax_d, + dsink, ) = flash_attn_gpu.bwd( dout, q, @@ -286,6 +295,8 @@ def _flash_attn_backward( deterministic, None, rng_state, + learnable_sink, + dsink, ) return softmax_d @@ -310,6 +321,8 @@ def _flash_attn_backward_fake( alibi_slopes: Optional[torch.Tensor], deterministic: bool, rng_state: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] if dq is None: @@ -318,6 +331,8 @@ def _flash_attn_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) + if dsink is None and learnable_sink is not None: + 
dsink = torch.empty_like(learnable_sink) batch_size, seqlen_q, num_heads, _ = q.shape if torch.cuda.is_available() and torch.version.hip: softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32) @@ -327,7 +342,7 @@ def _flash_attn_backward_fake( return softmax_d -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward else: _wrapped_flash_attn_backward = _flash_attn_backward @@ -358,6 +373,8 @@ def _flash_attn_varlen_backward( deterministic: bool, rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] @@ -366,6 +383,7 @@ def _flash_attn_varlen_backward( dk, dv, softmax_d, + dsink, ) = flash_attn_gpu.varlen_bwd( dout, q, @@ -391,6 +409,8 @@ def _flash_attn_varlen_backward( deterministic, None, rng_state, + learnable_sink, + dsink, ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() @@ -422,6 +442,8 @@ def _flash_attn_varlen_backward_fake( deterministic: bool, rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] batch_size = cu_seqlens_q.numel() - 1 @@ -433,6 +455,8 @@ def _flash_attn_varlen_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) + if dsink is None and learnable_sink is not None: + dsink = torch.empty_like(learnable_sink) if torch.cuda.is_available() and torch.version.hip: softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32) else: @@ -441,7 +465,7 @@ def 
_flash_attn_varlen_backward_fake( return softmax_d -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward else: _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward @@ -461,6 +485,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and qkv.requires_grad if softmax_scale is None: @@ -483,9 +508,10 @@ def forward( softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -498,9 +524,10 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, learnable_sink = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -524,9 +551,11 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, dsink class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): @@ -545,6 +574,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and qkv.requires_grad if softmax_scale is None: @@ -572,9 +602,10 @@ def forward( 
alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=None, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale @@ -588,9 +619,10 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens, rng_state, learnable_sink = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -618,9 +650,11 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, None, None, dsink class FlashAttnKVPackedFunc(torch.autograd.Function): @@ -638,6 +672,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, kv] @@ -662,9 +697,10 @@ def forward( softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -677,10 +713,11 @@ def forward( @staticmethod def 
backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, learnable_sink = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -704,10 +741,12 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None, None, dsink class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): @@ -729,6 +768,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, kv] @@ -758,10 +798,11 @@ def forward( alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=None, + learnable_sink=learnable_sink, ) if is_grad: ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q @@ -777,10 +818,11 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) 
head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -808,10 +850,12 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, dsink class FlashAttnFunc(torch.autograd.Function): @@ -830,6 +874,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] @@ -853,9 +898,10 @@ def forward( softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -868,8 +914,9 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, learnable_sink = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -893,11 +940,13 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None + return dq, dk, dv, 
None, None, None, None, None, None, None, None, None, dsink class FlashAttnVarlenFunc(torch.autograd.Function): @@ -921,6 +970,7 @@ def forward( return_softmax, block_table, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] @@ -949,10 +999,11 @@ def forward( alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=block_table, + learnable_sink=learnable_sink, ) if is_grad: ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q @@ -969,8 +1020,9 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -998,11 +1050,13 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, dsink def flash_attn_qkvpacked_func( @@ -1015,6 +1069,7 @@ def flash_attn_qkvpacked_func( alibi_slopes=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this 
function will be faster than @@ -1041,6 +1096,8 @@ def flash_attn_qkvpacked_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The @@ -1061,6 +1118,7 @@ def flash_attn_qkvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1075,6 +1133,7 @@ def flash_attn_kvpacked_func( alibi_slopes=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than @@ -1118,6 +1177,8 @@ def flash_attn_kvpacked_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). 
The @@ -1139,6 +1200,7 @@ def flash_attn_kvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1154,6 +1216,7 @@ def flash_attn_func( alibi_slopes=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -1194,6 +1257,8 @@ def flash_attn_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The @@ -1216,6 +1281,7 @@ def flash_attn_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1231,6 +1297,7 @@ def flash_attn_varlen_qkvpacked_func( alibi_slopes=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -1260,6 +1327,8 @@ def flash_attn_varlen_qkvpacked_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). 
The @@ -1282,6 +1351,7 @@ def flash_attn_varlen_qkvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1300,6 +1370,7 @@ def flash_attn_varlen_kvpacked_func( alibi_slopes=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than @@ -1349,6 +1420,8 @@ def flash_attn_varlen_kvpacked_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The @@ -1374,6 +1447,7 @@ def flash_attn_varlen_kvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1394,6 +1468,7 @@ def flash_attn_varlen_func( deterministic=False, return_attn_probs=False, block_table=None, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -1441,6 +1516,8 @@ def flash_attn_varlen_func( return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). 
The @@ -1468,6 +1545,7 @@ def flash_attn_varlen_func( return_attn_probs, block_table, torch.is_grad_enabled(), + learnable_sink, ) @@ -1491,6 +1569,7 @@ def flash_attn_with_kvcache( alibi_slopes=None, num_splits=0, return_softmax_lse=False, + learnable_sink=None, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -1612,5 +1691,6 @@ def flash_attn_with_kvcache( softcap, rotary_interleaved, num_splits, + learnable_sink, ) return (out, softmax_lse) if return_softmax_lse else out diff --git a/setup.py b/setup.py index fafea904998..2eb88cca0ed 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" SKIP_CK_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CK_BUILD", "TRUE") == "TRUE" if USE_TRITON_ROCM else False NVCC_THREADS = os.getenv("NVCC_THREADS") or "4" +DISABLE_SINK = os.getenv("FLASH_ATTENTION_DISABLE_SINK", "FALSE") == "TRUE" @functools.lru_cache(maxsize=None) def cuda_archs() -> str: @@ -271,6 +272,7 @@ def validate_and_update_archs(archs): # "-DFLASHATTENTION_DISABLE_SOFTCAP", # "-DFLASHATTENTION_DISABLE_UNEVEN_K", # "-DFLASHATTENTION_DISABLE_LOCAL", + # "-DFLASHATTENTION_DISABLE_SINK", ] compiler_c17_flag=["-O3", "-std=c++17"] diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index d5bb6ba8531..101937230ed 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1,4 +1,5 @@ import math +from typing import Optional import pytest import torch @@ -229,6 +230,7 @@ def attention_ref( upcast=True, reorder_ops=False, key_leftpad=None, + learnable_sink=None, ): """ Arguments: @@ -283,7 +285,15 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias - attention = torch.softmax(scores, dim=-1).to(v.dtype) + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + learnable_sink = learnable_sink.to(q.dtype) + 
sinks = repeat(learnable_sink, 'h -> b h n 1', b=scores.shape[0], n=scores.shape[2]) + scores = torch.cat([scores, sinks], dim=-1) + attention = torch.softmax(scores, dim=-1).to(v.dtype) + attention = attention[..., :-1] + # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) @@ -318,6 +328,7 @@ def attention_kvpacked_ref( upcast=True, reorder_ops=False, key_leftpad=None, + learnable_sink=None, ): return attention_ref( q, @@ -334,6 +345,7 @@ def attention_kvpacked_ref( softcap=softcap, reorder_ops=reorder_ops, key_leftpad=key_leftpad, + learnable_sink=learnable_sink, ) @@ -348,6 +360,7 @@ def attention_qkvpacked_ref( softcap=0.0, upcast=True, reorder_ops=False, + learnable_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -363,6 +376,7 @@ def attention_qkvpacked_ref( window_size=window_size, softcap=softcap, reorder_ops=reorder_ops, + learnable_sink=learnable_sink, ) @@ -473,6 +487,7 @@ def normalize_flash_attn_S( is_dropout=False, causal=False, window_size=(-1, -1), # -1 means infinite window size + learnable_sink=None, ): """ Arguments: @@ -507,11 +522,17 @@ def normalize_flash_attn_S( block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) - lse = torch.logsumexp(lse_block, dim=-1) + if learnable_sink is None: + lse = torch.logsumexp(lse_block, dim=-1) + scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) + else: + learnable_sink = learnable_sink.to(q.dtype) + sinks = repeat(learnable_sink, 'h -> b h n 1', b=scores.shape[0], n=scores.shape[2]) + lse = torch.logsumexp(torch.cat([lse_block, sinks], dim=-1), dim=-1) + scores_max_block = torch.stack([torch.amax(torch.cat([s, sinks], dim=-1), dim=-1) for s in 
scores_block], dim=-1) # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. lse[lse == float("-inf")] = float("inf") - scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) attn_norm = torch.cat( @@ -583,9 +604,12 @@ def get_dropout_fraction( # @pytest.mark.parametrize("seqlen", [512]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) -def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype, has_learnable_sink): if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set seed torch.random.manual_seed(0) @@ -600,6 +624,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal) else: alibi_slopes, attn_bias = None, None + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None out, lse, S_dmask = flash_attn_qkvpacked_func( qkv, dropout_p, @@ -608,6 +635,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( @@ -634,6 +662,7 @@ def 
test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, None, None, causal=causal, window_size=window_size @@ -643,7 +672,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ dropout_mask = None out_ref, attn_ref = attention_qkvpacked_ref( - qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size + qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size, learnable_sink=learnable_sink ) out_pt, attn_pt = attention_qkvpacked_ref( qkv, @@ -655,6 +684,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ window_size=window_size, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) # v = qkv[:, :, 2].float() # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float() @@ -687,9 +717,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64]) # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:]) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - (dqkv,) = torch.autograd.grad(out, qkv, g) - (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) - (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + (dqkv,) = torch.autograd.grad(out, qkv, g, retain_graph=has_learnable_sink) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g, retain_graph=has_learnable_sink) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") @@ -698,6 +728,14 @@ def 
test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -711,6 +749,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -730,11 +770,14 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ # @pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_varlen_qkvpacked( - seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype + seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype, has_learnable_sink ): if seqlen >= 2048 and 
torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set seed torch.random.manual_seed(0) @@ -755,6 +798,10 @@ def test_flash_attn_varlen_qkvpacked( else: alibi_slopes, attn_bias = None, None + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True ) @@ -769,6 +816,7 @@ def test_flash_attn_varlen_qkvpacked( alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) out = output_pad_fn(out_unpad) if dropout_p > 0.0: @@ -796,6 +844,7 @@ def test_flash_attn_varlen_qkvpacked( dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size @@ -812,6 +861,7 @@ def test_flash_attn_varlen_qkvpacked( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_qkvpacked_ref( qkv, @@ -823,6 +873,7 @@ def test_flash_attn_varlen_qkvpacked( window_size=window_size, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -834,10 +885,10 @@ def test_flash_attn_varlen_qkvpacked( g = torch.randn_like(out) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) + (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g, retain_graph=has_learnable_sink) dqkv = dqkv_pad_fn(dqkv_unpad) - 
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) - (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g, retain_graph=has_learnable_sink) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") @@ -846,6 +897,14 @@ def test_flash_attn_varlen_qkvpacked( print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. 
@@ -859,6 +918,8 @@ def test_flash_attn_varlen_qkvpacked( if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("kvpacked", [True, False]) @@ -900,8 +961,9 @@ def test_flash_attn_varlen_qkvpacked( @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap, has_learnable_sink ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -910,6 +972,8 @@ def test_flash_attn_output( pytest.skip() # Reference implementation OOM if softcap > 0.0 and dropout_p > 0.0: pytest.skip("Softcap and dropout not supported together") + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set seed torch.random.manual_seed(0) @@ -939,6 +1003,10 @@ def test_flash_attn_output( else: alibi_slopes, attn_bias = None, None + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None + if kvpacked: out, lse, S_dmask = flash_attn_kvpacked_func( q, @@ -950,6 +1018,7 @@ def test_flash_attn_output( alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) else: out, lse, S_dmask = flash_attn_func( @@ -963,6 +1032,7 @@ def test_flash_attn_output( alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) if dropout_p > 
0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( @@ -995,6 +1065,7 @@ def test_flash_attn_output( dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, None, None, causal=causal, window_size=window_size @@ -1015,6 +1086,7 @@ def test_flash_attn_output( causal=causal, window_size=window_size, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -1029,6 +1101,7 @@ def test_flash_attn_output( softcap=softcap, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) else: out_ref, attn_ref = attention_ref( @@ -1043,6 +1116,7 @@ def test_flash_attn_output( causal=causal, window_size=window_size, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_ref( q, @@ -1058,6 +1132,7 @@ def test_flash_attn_output( softcap=softcap, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1075,34 +1150,34 @@ def test_flash_attn_output( ( dq, dkv, - ) = torch.autograd.grad(out, (q, kv), g) + ) = torch.autograd.grad(out, (q, kv), g, retain_graph=has_learnable_sink) dk, dv = dkv.unbind(2) ( dq_ref, dkv_ref, - ) = torch.autograd.grad(out_ref, (q, kv), g) + ) = torch.autograd.grad(out_ref, (q, kv), g, retain_graph=has_learnable_sink) dk_ref, dv_ref = dkv_ref.unbind(2) ( dq_pt, dkv_pt, - ) = torch.autograd.grad(out_pt, (q, kv), g) + ) = torch.autograd.grad(out_pt, (q, kv), g, retain_graph=has_learnable_sink) dk_pt, dv_pt = dkv_pt.unbind(2) else: ( dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = 
torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1115,6 +1190,14 @@ def test_flash_attn_output( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. 
@@ -1130,6 +1213,9 @@ def test_flash_attn_output( assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + if has_learnable_sink: + mult = 3 if not (dtype == torch.float16 and softcap > 0) else softcap + assert (dsink - dsink_ref).abs().max().item() <= mult * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("kvpacked", [True, False]) @@ -1146,9 +1232,9 @@ def test_flash_attn_output( # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) -@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [64]) +@pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1169,8 +1255,9 @@ def test_flash_attn_output( @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("softcap", [0.0, 50.0]) # @pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap, has_learnable_sink ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -1179,6 +1266,8 @@ def test_flash_attn_varlen_output( pytest.skip() # Reference implementation OOM if softcap > 0.0 and dropout_p > 0.0: pytest.skip("Softcap and dropout not supported together") + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set 
seed torch.random.manual_seed(0) @@ -1203,7 +1292,6 @@ def test_flash_attn_varlen_output( v = torch.randn( batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True ) - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') @@ -1215,6 +1303,10 @@ def test_flash_attn_varlen_output( else: alibi_slopes, attn_bias = None, None + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None + if kvpacked: ( q_unpad, @@ -1243,6 +1335,7 @@ def test_flash_attn_varlen_output( alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) else: ( @@ -1275,6 +1368,7 @@ def test_flash_attn_varlen_output( alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) out = output_pad_fn(out_unpad) if dropout_p > 0.0: @@ -1308,6 +1402,7 @@ def test_flash_attn_varlen_output( dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, @@ -1332,6 +1427,7 @@ def test_flash_attn_varlen_output( causal=causal, window_size=window_size, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -1346,6 +1442,7 @@ def test_flash_attn_varlen_output( softcap=softcap, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) else: out_ref, attn_ref = attention_ref( @@ -1360,6 +1457,7 @@ def test_flash_attn_varlen_output( causal=causal, window_size=window_size, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_ref( q, @@ -1375,6 +1473,7 @@ def test_flash_attn_varlen_output( softcap=softcap, 
upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1391,36 +1490,36 @@ def test_flash_attn_varlen_output( ( dq_unpad, dkv_unpad, - ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g) + ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g, retain_graph=has_learnable_sink) dk, dv = dkv_pad_fn(dkv_unpad).unbind(2) ( dq_ref, dkv_ref, - ) = torch.autograd.grad(out_ref, (q, kv), g) + ) = torch.autograd.grad(out_ref, (q, kv), g, retain_graph=has_learnable_sink) dk_ref, dv_ref = dkv_ref.unbind(2) ( dq_pt, dkv_pt, - ) = torch.autograd.grad(out_pt, (q, kv), g) + ) = torch.autograd.grad(out_pt, (q, kv), g, retain_graph=has_learnable_sink) dk_pt, dv_pt = dkv_pt.unbind(2) else: ( dq_unpad, dk_unpad, dv_unpad, - ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=has_learnable_sink) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) dq = dq_pad_fn(dq_unpad) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -1434,6 +1533,14 @@ def test_flash_attn_varlen_output( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") 
+ print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1449,6 +1556,8 @@ def test_flash_attn_varlen_output( assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1478,8 +1587,8 @@ def test_flash_attn_varlen_output( (1023, 1024), ], ) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_learnable_sink): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1497,9 +1606,12 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size) + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None + out = 
flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, learnable_sink=learnable_sink) out_ref, attn_ref = attention_ref( - q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size + q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size, learnable_sink=learnable_sink ) out_pt, attn_pt = attention_ref( q, @@ -1514,6 +1626,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): window_size=window_size, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1527,17 +1640,17 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1550,6 +1663,14 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") 
+ print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1558,6 +1679,8 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1590,8 +1713,9 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_varlen_causal( - seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype + seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype, has_learnable_sink ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -1621,6 +1745,9 @@ def test_flash_attn_varlen_causal( k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype ) + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") key_padding_mask = 
generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") ( @@ -1650,6 +1777,7 @@ def test_flash_attn_varlen_causal( causal=causal, window_size=window_size, block_table=block_table, + learnable_sink=learnable_sink, ) out = output_pad_fn(out_unpad) out_ref, attn_ref = attention_ref( @@ -1663,6 +1791,7 @@ def test_flash_attn_varlen_causal( None, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_ref( q, @@ -1677,6 +1806,7 @@ def test_flash_attn_varlen_causal( window_size=window_size, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1692,7 +1822,7 @@ def test_flash_attn_varlen_causal( dq_unpad, dk_unpad, dv_unpad, - ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=has_learnable_sink) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) @@ -1700,12 +1830,12 @@ def test_flash_attn_varlen_causal( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1718,6 +1848,14 @@ def test_flash_attn_varlen_causal( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = 
torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1727,6 +1865,8 @@ def test_flash_attn_varlen_causal( assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1762,9 +1902,12 @@ def test_flash_attn_varlen_causal( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_splitkv( - seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype + seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype, has_learnable_sink ): + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q device = "cuda" @@ -1781,6 +1924,11 @@ def test_flash_attn_splitkv( attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) else: alibi_slopes, attn_bias = None, None + + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None + out, lse, _ = flash_attn_func( q, k, @@ -1791,9 +1939,10 @@ def 
test_flash_attn_splitkv( alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) out_ref, attn_ref = attention_ref( - q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size + q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size, learnable_sink=learnable_sink ) out_pt, attn_pt = attention_ref( q, @@ -1808,6 +1957,7 @@ def test_flash_attn_splitkv( window_size=window_size, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1821,17 +1971,17 @@ def test_flash_attn_splitkv( dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1844,15 +1994,25 @@ def test_flash_attn_splitkv( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + 
print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - mult = 2 if not alibi else 8 + mult = 2 if not (alibi or (has_learnable_sink and dtype == torch.float16)) else 8 assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() + 2e-4 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1904,6 +2064,7 @@ def test_flash_attn_splitkv( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_kvcache( seqlen_q, seqlen_k, @@ -1921,6 +2082,7 @@ def test_flash_attn_kvcache( mha_type, num_splits, dtype, + has_learnable_sink, ): if seqlen_q > seqlen_k and new_kv: pytest.skip() @@ -1930,6 +2092,9 @@ def test_flash_attn_kvcache( pytest.skip() if has_leftpad and paged_kv_block_size is not None: pytest.skip() + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") + device = "cuda" # set seed torch.random.manual_seed(0) @@ -2001,6 +2166,10 @@ def test_flash_attn_kvcache( ) else: alibi_slopes, attn_bias = None, None + + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) if rotary_dim > 0: angle = ( @@ -2069,6 +2238,7 @@ def test_flash_attn_kvcache( 
rotary_interleaved=rotary_interleaved, alibi_slopes=alibi_slopes, num_splits=num_splits, + learnable_sink=learnable_sink, ) # out = flash_attn_with_kvcache( # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size @@ -2092,6 +2262,7 @@ def test_flash_attn_kvcache( causal=causal, window_size=window_size, key_leftpad=cache_leftpad, + learnable_sink=learnable_sink, ) out_pt, _ = attention_ref( q_ro, @@ -2107,6 +2278,7 @@ def test_flash_attn_kvcache( upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -2136,7 +2308,7 @@ def test_flash_attn_kvcache( )[:, :seqlen_k] assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) assert torch.equal(v_cache_select, v_cache_ref) - mult = 3 if not alibi else 5 + mult = 3 if not (alibi or (has_learnable_sink and dtype == torch.float16)) else 5 assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 @@ -2196,7 +2368,8 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, ) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) -def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype, has_learnable_sink): device = "cuda" # set seed torch.random.manual_seed(0) @@ -2205,21 +2378,27 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, 
requires_grad=True) + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None torch.random.manual_seed(42) - out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) + out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True, learnable_sink=learnable_sink) g = torch.randn_like(out0) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): ( dq0, dk0, dv0, - ) = torch.autograd.grad(out0, (q, k, v), g) + ) = torch.autograd.grad(out0, (q, k, v), g, retain_graph=has_learnable_sink) # Numerical error if we just do any arithmetic on dq dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + if has_learnable_sink: + dlearnable_sink0, = torch.autograd.grad(out0, (learnable_sink,), g) + dsink_atol = 2 * ((dlearnable_sink0 + 0.3 - 0.3) - dlearnable_sink0).abs().max().item() + 2e-4 for i in range(250): torch.random.manual_seed(42) - out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) + out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True, learnable_sink=learnable_sink) assert torch.equal(out, out0) assert torch.equal(lse, lse0) @@ -2228,13 +2407,19 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) dq_equal = torch.allclose(dq, dq0, atol=dq_atol) if not dq_equal: print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert dq_equal + if has_learnable_sink: + dlearnable_sink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_equal = torch.allclose(dlearnable_sink, dlearnable_sink0, atol=dsink_atol) + if not dsink_equal: + print(f"Iter {i}, {dsink_atol = }, dSink max diff: 
{(dlearnable_sink - dlearnable_sink0).abs().max().item()}") + assert dsink_equal @pytest.mark.parametrize("dtype", [torch.float16]) @@ -2244,7 +2429,8 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty # @pytest.mark.parametrize('d', [16]) @pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128]) # @pytest.mark.parametrize('seqlen', [2]) -def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype, has_learnable_sink): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0. """ @@ -2261,18 +2447,24 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) - out = flash_attn_func(q, k, v, causal=causal) + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=dtype, requires_grad=True) + else: + learnable_sink = None + out = flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink) g = torch.randn_like(out) out.backward(g) q_pt = q.detach().clone().requires_grad_(True) k_pt = k.detach().clone().requires_grad_(True) v_pt = v.detach().clone().requires_grad_(True) - out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + learnable_sink_pt = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, learnable_sink=learnable_sink_pt, upcast=False, reorder_ops=True) out_pt.backward(g) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) v_ref = v.detach().clone().requires_grad_(True) - out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + learnable_sink_ref = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_ref, 
attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal, learnable_sink=learnable_sink_ref) out_ref.backward(g) print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") @@ -2290,6 +2482,12 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): assert (v.grad - v_ref.grad).abs().max().item() <= 5 * ( v_pt.grad - v_ref.grad ).abs().max().item() + 1e-3 + if has_learnable_sink: + print(f"dSink max diff: {(learnable_sink.grad - learnable_sink_ref.grad).abs().max().item()}") + print(f"dSink Pytorch max diff: {(learnable_sink_pt.grad - learnable_sink_ref.grad).abs().max().item()}") + assert (learnable_sink.grad - learnable_sink_ref.grad).abs().max().item() <= 2 * ( + learnable_sink_pt.grad - learnable_sink_ref.grad + ).abs().max().item() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -2300,7 +2498,8 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256]) # @pytest.mark.parametrize('seqlen', [128]) -def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype, has_learnable_sink): """We previously had a bug where we were using the wrong strides of dout, which shows up when dout is not contiguous. """ @@ -2313,20 +2512,26 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True) for _ in range(3) ] - out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... 
-> s b ...") + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=dtype, requires_grad=True) + else: + learnable_sink = None + out = rearrange(flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink), "b s ... -> s b ...") # So g is not contiguous g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] out.backward(g) q_pt = q.detach().clone().requires_grad_(True) k_pt = k.detach().clone().requires_grad_(True) v_pt = v.detach().clone().requires_grad_(True) - out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + learnable_sink_pt = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, learnable_sink=learnable_sink_pt, upcast=False, reorder_ops=True) out_pt = rearrange(out_pt, "b s ... -> s b ...") out_pt.backward(g) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) v_ref = v.detach().clone().requires_grad_(True) - out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + learnable_sink_ref = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal, learnable_sink=learnable_sink_ref) out_ref = rearrange(out_ref, "b s ... 
-> s b ...") out_ref.backward(g) print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") @@ -2345,6 +2550,12 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): assert (v.grad - v_ref.grad).abs().max().item() <= 2 * ( v_pt.grad - v_ref.grad ).abs().max().item() + if has_learnable_sink: + print(f"dSink max diff: {(learnable_sink.grad - learnable_sink_ref.grad).abs().max().item()}") + print(f"dSink Pytorch max diff: {(learnable_sink_pt.grad - learnable_sink_ref.grad).abs().max().item()}") + assert (learnable_sink.grad - learnable_sink_ref.grad).abs().max().item() <= 2 * ( + learnable_sink_pt.grad - learnable_sink_ref.grad + ).abs().max().item() @pytest.mark.parametrize("dtype", [torch.float16]) @@ -2352,7 +2563,8 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [16, 32, 64]) # @pytest.mark.parametrize('d', [16]) -def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_bwd_varlen_overflow(d, causal, dtype, has_learnable_sink): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0 or varlen. 
""" @@ -2370,15 +2582,20 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=dtype, requires_grad=True) + else: + learnable_sink = None - out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal) + out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal, learnable_sink=learnable_sink) g = torch.randn_like(out) out.backward(g) assert not q.grad.isnan().any() assert not k.grad.isnan().any() assert not v.grad.isnan().any() - + if has_learnable_sink: + assert not learnable_sink.grad.isnan().any() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16])