From d715844e14e15efdf219137b30f55a81b26dd33f Mon Sep 17 00:00:00 2001 From: jerryao Date: Fri, 15 Aug 2025 11:38:08 +0800 Subject: [PATCH 01/29] Add new API. --- csrc/flash_attn/flash_api.cpp | 391 ++++++++++++++++++++++++++++ csrc/flash_attn/src/flash.h | 4 + csrc/flash_attn/src/static_switch.h | 10 + flash_attn/__init__.py | 1 + flash_attn/flash_attn_interface.py | 215 ++++++++++++++- 5 files changed, 620 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index dd7a5c3f9b4..82336f9768a 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -468,6 +468,174 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult window_size_right, softcap ); + params.sink_ptr = nullptr; + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( + params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, + head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); + + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // Forward kernel will populate memory with the seed and offset. + params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + return {out, softmax_lse, p, rng_state}; +} + +std::vector +mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &sink, // num_heads + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + } + } else { + out = torch::empty_like(q); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + else { + p = torch::empty({ 0 }, opts); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + softcap + ); + params.sink_ptr = sink.data_ptr(); + // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; @@ -930,6 +1098,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl deterministic, /*unpadded_lse*/false); params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); + params.sink_ptr = nullptr; + params.dsink_ptr = nullptr; auto launch = &run_mha_bwd; @@ -970,6 +1140,225 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl return { dq, dk, dv, softmax_d }; } +std::vector +mha_sink_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &sink, // num_heads + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dsink_, // num_heads + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + std::optional gen_, + std::optional &rng_state) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + if (is_causal) { window_size_right = 0; } + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); + + at::Tensor dq, dk, dv, dsink; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dv = torch::empty_like(v); + } + if (dsink_.has_value()) { + dsink = dsink_.value(); + TORCH_CHECK(dsink.dtype() == sink.dtype(), "dsink must have the same dtype as sink"); + CHECK_DEVICE(dsink); + TORCH_CHECK(dsink.stride(-1) == 1, "dsink must have contiguous last dimension"); + CHECK_SHAPE(dsink, num_heads); + } else { + dsink = torch::empty_like(sink); + } + + // bool loop = seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + at::Tensor dk_accum, dv_accum; + if (loop) { + if (!deterministic) { + dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } else { + const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); + dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } + // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout, dq, dk_expanded, dv_expanded, + nullptr, + nullptr, + loop ? dq_accum.data_ptr() : nullptr, + // loop ? dk_accum.data_ptr() : nullptr, + // loop ? dv_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + softcap, + deterministic, + /*unpadded_lse*/false); + params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); + params.sink_ptr = sink.data_ptr(); + params.dsink_ptr = dsink.data_ptr(); + + auto launch = &run_mha_bwd; + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + + if ( rng_state.has_value() ) { + params.rng_state = reinterpret_cast(rng_state.value().data_ptr()); + } else if( is_dropout ) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + auto seeds = at::cuda::philox::unpack(params.philox_args); + params.rng_state[0] = std::get<0>(seeds); + params.rng_state[1] = std::get<1>(seeds); + } + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (seqlen_q > 0) { + launch(params, stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + } + + return { dq, dk, dv, dsink, softmax_d }; +} + std::vector mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i @@ -1479,7 +1868,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)"); + m.def("sink_fwd", &FLASH_NAMESPACE::mha_sink_fwd, "Forward pass (with sink)"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)"); + m.def("sink_bwd", &FLASH_NAMESPACE::mha_sink_bwd, "Backward pass (with sink)"); m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache"); } diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 8ffbb62d66e..89613ddb085 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -140,6 +140,8 @@ struct Flash_fwd_params : public Qkv_params { bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). + + void *__restrict__ sink_ptr; // For gpt_oss }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -182,6 +184,8 @@ struct Flash_bwd_params : public Flash_fwd_params { bool deterministic; index_t dq_accum_split_stride; + + void *__restrict__ dsink_ptr; // For gpt_oss }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index 70d14daf69d..7daa1592b0d 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -109,3 +109,13 @@ return __VA_ARGS__(); \ } \ }() + +#ifdef FLASHATTENTION_DISABLE_SINK + #define SINK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SINK_SWITCH BOOL_SWITCH +#endif diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 69eae460e36..7657c85eec1 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -2,6 +2,7 @@ from flash_attn.flash_attn_interface import ( flash_attn_func, + flash_attn_sink_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func, flash_attn_varlen_func, diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1e041e4538d..ae5850b9b20 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -133,7 +133,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "2.4.0": +if torch.__version__ >= "12.4.0": _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward else: _wrapped_flash_attn_forward = _flash_attn_forward @@ -888,6 +888,188 @@ def backward(ctx, dout, *args): dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None, None + +@_torch_custom_op_wrapper("flash_attn::_flash_attn_sink_forward", mutates_args=(), device_types="cuda") +def _flash_attn_sink_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sink: torch.Tensor, + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + return_softmax: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v, sink = [maybe_contiguous(x) for x in (q, k, v, sink)] + out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.sink_fwd( + q, + k, + v, + sink, + None, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + return_softmax, + None, + ) + return out, softmax_lse, S_dmask, rng_state + +@_torch_custom_op_wrapper("flash_attn::_flash_attn_sink_backward", mutates_args=("dq", "dk", "dv", "dsink"), device_types="cuda") +def _flash_attn_sink_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sink: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + dsink: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, sink, out = [maybe_contiguous(x) for x in (dout, q, k, v, sink, out)] + ( + dq, + dk, + dv, + dsink, + softmax_d, + ) = flash_attn_gpu.sink_bwd( + dout, + q, + k, + v, + sink, + out, + softmax_lse, + dq, + dk, + dv, + dsink, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, + window_size_right, + softcap, + deterministic, + None, + rng_state, + ) + return softmax_d + +_wrapped_flash_attn_sink_forward = _flash_attn_sink_forward +_wrapped_flash_attn_sink_backward = _flash_attn_sink_backward + +class FlashAttnSinkFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + sink, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_sink_forward( + q, + k, + v, + sink, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + if is_grad: + ctx.save_for_backward(q, k, v, sink, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_padded[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, sink, out, softmax_lse, rng_state = ctx.saved_tensors + dq, dk, dv, dsink = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(sink) + head_size_og = dout.size(3) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + _wrapped_flash_attn_sink_backward( + dout_padded, + q, + k, + v, + sink, + out, + softmax_lse, + dq, + dk, + dv, + dsink, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, dsink, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -1209,6 +1391,37 @@ def flash_attn_func( ) +def flash_attn_sink_func( + q, + k, + v, + sink, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + return FlashAttnSinkFunc.apply( + q, + k, + v, + sink, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) + + def flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens, From 5501390be220aa9f4b13d64fbf4cc10f4efab0fb Mon Sep 17 00:00:00 2001 From: jerryao Date: Fri, 15 Aug 2025 11:46:18 +0800 Subject: [PATCH 02/29] Add tests for sink. --- benchmarks/benchmark_flash_attention.py | 29 ++- benchmarks/flash_attn_with_sink.py | 207 ++++++++++++++++++++ benchmarks/flash_attn_with_sink_fused.py | 212 ++++++++++++++++++++ benchmarks/naive_attn_with_sink.py | 64 ++++++ benchmarks/test.py | 232 ++++++++++++++++++++++ benchmarks/test_fused.py | 237 +++++++++++++++++++++++ 6 files changed, 980 insertions(+), 1 deletion(-) create mode 100644 benchmarks/flash_attn_with_sink.py create mode 100644 benchmarks/flash_attn_with_sink_fused.py create mode 100644 benchmarks/naive_attn_with_sink.py create mode 100644 benchmarks/test.py create mode 100644 benchmarks/test_fused.py diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 341ae4b2139..f3030bc86b6 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -12,6 +12,8 @@ from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined from flash_attn import flash_attn_qkvpacked_func +from flash_attn import flash_attn_func +from flash_attn_with_sink import flash_attn_with_sink_func try: from triton.ops.flash_attention import attention as attention_triton @@ -77,7 +79,7 @@ def time_fwd_bwd(func, *args, **kwargs): dim = 2048 dropout_p = 0.0 -methods = (["Flash2", "Pytorch"] +methods = (["Flash2", "Flash2UnPacked", "Pytorch", "Flash2Sink"] + (["Triton"] if attention_triton is not None else []) + (["xformers.c"] if xops is not None else []) + (["xformers.f"] if xops is not None else [])) @@ -101,6 +103,14 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "Flash2"] = f time_b[config, "Flash2"] = b + q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in range(3)] + f, b = time_fwd_bwd( + flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False + ) + time_f[config, "Flash2UnPacked"] = f + time_b[config, "Flash2UnPacked"] = b + try: qkv = qkv.detach().requires_grad_(True) f, b = time_fwd_bwd( @@ -111,6 +121,23 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "Pytorch"] = f time_b[config, "Pytorch"] = b + try: + scaling = nheads**-0.5 + num_key_value_heads = nheads # // 8 + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + k, v = [torch.randn(batch_size, seqlen, num_key_value_heads, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in range(2)] + sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True) + + f, b = time_fwd_bwd( + flash_attn_with_sink_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False + ) + except: # Skip if OOM + f, b = float('nan'), float('nan') + time_f[config, "Flash2Sink"] = f + time_b[config, "Flash2Sink"] = b + if attention_triton is not None: q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] diff --git a/benchmarks/flash_attn_with_sink.py b/benchmarks/flash_attn_with_sink.py new file mode 100644 index 00000000000..801f27a789b --- /dev/null +++ b/benchmarks/flash_attn_with_sink.py @@ -0,0 +1,207 @@ +import torch +from flash_attn import flash_attn_func +from flash_attn.flash_attn_interface import _flash_attn_backward + + +class FlashAttentionWithSink(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + sink: torch.Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + ): + # Check device + if q.device.type != 'cuda': + raise RuntimeError( + f"Flash Attention only supports CUDA devices, " + f"current device: {q.device}" + ) + + ctx.save_for_backward(q, k, v, sink, alibi_slopes) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.deterministic = deterministic + ctx.return_attn_probs = return_attn_probs + ctx.sink_shape = sink.shape # Save original sink shape + + # import pdb; pdb.set_trace() + + out, lse, _ = flash_attn_func( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + print("==== lse shape: ", lse.shape) + + origin_dtype = out.dtype + + ctx.raw_output = out.clone() + ctx.lse = lse.clone() + + lse = lse.transpose(-2, -1).unsqueeze(dim=-1) + sink = sink.reshape(1, 1, -1, 1) + + multiplier = 1 / (torch.exp(sink - lse) + 1) + out = (out * multiplier).to(origin_dtype) + + return out + + @staticmethod + def backward(ctx, grad_output): + q, k, v, sink, alibi_slopes = ctx.saved_tensors + raw_output = ctx.raw_output + lse = ctx.lse + + lse = lse.transpose(-2, -1).unsqueeze(dim=-1) + sink_reshaped = sink.reshape(1, 1, -1, 1) + multiplier = 1 / (torch.exp(sink_reshaped - lse) + 1) + + # 1) Main path via multiplier + grad_raw_output = (grad_output * multiplier).to(q.dtype) + + # Use flash attention backward function for main path + grad_q_main = torch.empty_like(q) + grad_k_main = torch.empty_like(k) + grad_v = torch.empty_like(v) + + _flash_attn_backward( + grad_raw_output, # dout: main path gradient + q, # q + k, # k + v, # v + ctx.raw_output, # out: original output + lse, # softmax_lse + grad_q_main, # dq: main path + grad_k_main, # dk: main path + grad_v, # dv: main path + ctx.dropout_p, # dropout_p + ctx.softmax_scale, # softmax_scale + ctx.causal, # causal + ctx.window_size[0], # window_size_left + ctx.window_size[1], # window_size_right + ctx.softcap, # softcap + alibi_slopes, # alibi_slopes + ctx.deterministic, # deterministic + ) + + # 2) Sink gradient path + # g_r = (grad_output * raw_output).sum(dim=-1) # [B,H,Nq] + g_r = torch.sum(grad_output * raw_output, dim=-1) + + # g_ell = g_r * multiplier * (1 - multiplier) # [B,H,Nq] + # Based on debug output: + # g_r shape: [1, 512, 64] (batch, seq_len, heads) + # multiplier shape: [1, 512, 64, 1] (batch, seq_len, heads, 1) + # We need multiplier_for_grad to have shape [1, 512, 64] + # [1, 512, 64, 1] -> [1, 512, 64] + multiplier_for_grad = multiplier.squeeze(-1) + + g_ell = g_r * multiplier_for_grad * (1 - multiplier_for_grad) + # Based on shapes: g_ell [1, 512, 64], we need to sum over seq_len (dim=1) + # to get [1, 64], then sum over batch (dim=0) to get [64] + grad_sink = -torch.sum(g_ell, dim=1) # Sum over seq_len -> [1, 64] + # Sum over batch dimension and reshape to match original sink shape + grad_sink = grad_sink.sum(dim=0) # Sum over batch -> [64] + grad_sink = grad_sink.reshape(ctx.sink_shape) + + # 3) Additional Q gradient via sink + # dQ_extra = scale * g_ell * attention(Q,K,K) + scale = ctx.softmax_scale or (1.0 / q.shape[-1] ** 0.5) + + # Compute attention(Q,K,K) for additional Q gradient + mu_k = flash_attn_func( + q, k, k, + dropout_p=ctx.dropout_p, + softmax_scale=ctx.softmax_scale, + causal=ctx.causal, + window_size=ctx.window_size, + softcap=ctx.softcap, + alibi_slopes=alibi_slopes, + deterministic=ctx.deterministic, + return_attn_probs=False, + ) + grad_q_extra = scale * g_ell.unsqueeze(-1) * mu_k + + # 4) Additional K gradient via sink + # dK_extra = scale * P^T (g_ell * Q) + x = (g_ell.unsqueeze(-1) * q).to(q.dtype) + + # Use flash attention backward to compute P^T X + grad_k_extra = torch.empty_like(k) + _flash_attn_backward( + x, # dout: g_ell * Q + q, # q + k, # k + k, # v (dummy, using K as V) + ctx.raw_output, # out: original output + lse, # softmax_lse + None, # dq: not needed + None, # dk: not needed + grad_k_extra, # dv: this will be dK_extra + ctx.dropout_p, # dropout_p + ctx.softmax_scale, # softmax_scale + ctx.causal, # causal + ctx.window_size[0], # window_size_left + ctx.window_size[1], # window_size_right + ctx.softcap, # softcap + alibi_slopes, # alibi_slopes + ctx.deterministic, # deterministic + ) + grad_k_extra = scale * grad_k_extra + + # 5) Sum all gradients + grad_q = grad_q_main + grad_q_extra + grad_k = grad_k_main + grad_k_extra + # grad_v already from main path + + return (grad_q, grad_k, grad_v, grad_sink, None, None, None, None, + None, None, None, None) + + +def flash_attn_with_sink_func( + q, + k, + v, + sink: torch.Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + # Check CUDA availability + if not torch.cuda.is_available(): + raise RuntimeError( + "Flash Attention requires CUDA devices. " + "Current device does not support CUDA." + ) + + return FlashAttentionWithSink.apply( + q, k, v, sink, dropout_p, softmax_scale, causal, + window_size, softcap, alibi_slopes, deterministic, return_attn_probs + ) diff --git a/benchmarks/flash_attn_with_sink_fused.py b/benchmarks/flash_attn_with_sink_fused.py new file mode 100644 index 00000000000..a554a7cf371 --- /dev/null +++ b/benchmarks/flash_attn_with_sink_fused.py @@ -0,0 +1,212 @@ +import torch +from flash_attn import flash_attn_sink_func +from flash_attn.flash_attn_interface import _flash_attn_sink_backward + + +class FlashAttentionWithSinkFused(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + sink: torch.Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + ): + # Check device + if q.device.type != 'cuda': + raise RuntimeError( + f"Flash Attention only supports CUDA devices, " + f"current device: {q.device}" + ) + + ctx.save_for_backward(q, k, v, sink, alibi_slopes) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.deterministic = deterministic + ctx.return_attn_probs = return_attn_probs + ctx.sink_shape = sink.shape # Save original sink shape + + # import pdb; pdb.set_trace() + + out, lse, _ = flash_attn_sink_func( + q, + k, + v, + sink, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + print("==== lse shape: ", lse.shape) + + origin_dtype = out.dtype + + ctx.raw_output = out.clone() + ctx.lse = lse.clone() + + lse = lse.transpose(-2, -1).unsqueeze(dim=-1) + # sink = sink.reshape(1, 1, -1, 1) + + # multiplier = 1 / (torch.exp(sink - lse) + 1) + # out = (out * multiplier).to(origin_dtype) + + return out + + @staticmethod + def backward(ctx, grad_output): + q, k, v, sink, alibi_slopes = ctx.saved_tensors + raw_output = ctx.raw_output + lse = ctx.lse + + lse = lse.transpose(-2, -1).unsqueeze(dim=-1) + sink_reshaped = sink.reshape(1, 1, -1, 1) + multiplier = 1 / (torch.exp(sink_reshaped - lse) + 1) + + # 1) Main path via multiplier + grad_raw_output = (grad_output * multiplier).to(q.dtype) + + # Use flash attention backward function for main path + grad_q_main = torch.empty_like(q) + grad_k_main = torch.empty_like(k) + grad_v = torch.empty_like(v) + + _flash_attn_sink_backward( + grad_raw_output, # dout: main path gradient + q, # q + k, # k + v, # v + sink, + ctx.raw_output, # out: original output + lse, # softmax_lse + grad_q_main, # dq: main path + grad_k_main, # dk: main path + grad_v, # dv: main path + torch.empty_like(sink), # ds: sink gradient + ctx.dropout_p, # dropout_p + ctx.softmax_scale, # softmax_scale + ctx.causal, # causal + ctx.window_size[0], # window_size_left + ctx.window_size[1], # window_size_right + ctx.softcap, # softcap + alibi_slopes, # alibi_slopes + ctx.deterministic, # deterministic + ) + + # 2) Sink gradient path + # g_r = (grad_output * raw_output).sum(dim=-1) # [B,H,Nq] + g_r = torch.sum(grad_output * raw_output, dim=-1) + + # g_ell = g_r * multiplier * (1 - multiplier) # [B,H,Nq] + # Based on debug output: + # g_r shape: [1, 512, 64] (batch, seq_len, heads) + # multiplier shape: [1, 512, 64, 1] (batch, seq_len, heads, 1) + # We need multiplier_for_grad to have shape [1, 512, 64] + # [1, 512, 64, 1] -> [1, 512, 64] + multiplier_for_grad = multiplier.squeeze(-1) + + g_ell = g_r * multiplier_for_grad * (1 - multiplier_for_grad) + # Based on shapes: g_ell [1, 512, 64], we need to sum over seq_len (dim=1) + # to get [1, 64], then sum over batch (dim=0) to get [64] + grad_sink = -torch.sum(g_ell, dim=1) # Sum over seq_len -> [1, 64] + # Sum over batch dimension and reshape to match original sink shape + grad_sink = grad_sink.sum(dim=0) # Sum over batch -> [64] + grad_sink = grad_sink.reshape(ctx.sink_shape) + + # 3) Additional Q gradient via sink + # dQ_extra = scale * g_ell * attention(Q,K,K) + scale = ctx.softmax_scale or (1.0 / q.shape[-1] ** 0.5) + + # Compute attention(Q,K,K) for additional Q gradient + mu_k = flash_attn_sink_func( + q, k, k, sink, + dropout_p=ctx.dropout_p, + softmax_scale=ctx.softmax_scale, + causal=ctx.causal, + window_size=ctx.window_size, + softcap=ctx.softcap, + alibi_slopes=alibi_slopes, + deterministic=ctx.deterministic, + return_attn_probs=False, + ) + grad_q_extra = scale * g_ell.unsqueeze(-1) * mu_k + + # 4) Additional K gradient via sink + # dK_extra = scale * P^T (g_ell * Q) + x = (g_ell.unsqueeze(-1) * q).to(q.dtype) + + # Use flash attention backward to compute P^T X + grad_k_extra = torch.empty_like(k) + _flash_attn_sink_backward( + x, # dout: g_ell * Q + q, # q + k, # k + k, # v (dummy, using K as V) + sink, + ctx.raw_output, # out: original output + lse, # softmax_lse + None, # dq: not needed + None, # dk: not needed + grad_k_extra, # dv: this will be dK_extra + torch.empty_like(sink), + ctx.dropout_p, # dropout_p + ctx.softmax_scale, # softmax_scale + ctx.causal, # causal + ctx.window_size[0], # window_size_left + ctx.window_size[1], # window_size_right + ctx.softcap, # softcap + alibi_slopes, # alibi_slopes + ctx.deterministic, # deterministic + ) + grad_k_extra = scale * grad_k_extra + + # 5) Sum all gradients + grad_q = grad_q_main + grad_q_extra + grad_k = grad_k_main + grad_k_extra + # grad_v already from main path + + return (grad_q, grad_k, grad_v, grad_sink, None, None, None, None, + None, None, None, None) + + +def flash_attn_with_sink_fused_func( + q, + k, + v, + sink: torch.Tensor, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + # Check CUDA availability + if not torch.cuda.is_available(): + raise RuntimeError( + "Flash Attention requires CUDA devices. " + "Current device does not support CUDA." + ) + + return FlashAttentionWithSinkFused.apply( + q, k, v, sink, dropout_p, softmax_scale, causal, + window_size, softcap, alibi_slopes, deterministic, return_attn_probs + ) diff --git a/benchmarks/naive_attn_with_sink.py b/benchmarks/naive_attn_with_sink.py new file mode 100644 index 00000000000..79d7be6eacb --- /dev/null +++ b/benchmarks/naive_attn_with_sink.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + print(num_key_value_heads, n_rep) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# from https://github.com/huggingface/transformers/blob/369c99d0cea403b77bd0aef818527106453fd9fc/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L227 +def eager_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sink: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + num_key_value_groups: int = 8, + **kwargs, +): + key_states = repeat_kv(key, num_key_value_groups) + value_states = repeat_kv(value, num_key_value_groups) + print("==== sink: ", sink.shape, sink) + print("==== query shape: ", query.shape) + print("==== key_states shape: ", key_states.shape) + print("==== value_states shape: ", value_states.shape) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + print("==== attn_weights shape: ", attn_weights.shape) + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + print("==== causal_mask shape: ", causal_mask.shape) + attn_weights = attn_weights + causal_mask + + sinks = sink.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + print("==== sinks shape: ", sinks.shape) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + print("==== combined_logits shape: ", combined_logits.shape) + # This was not in the original implementation and slightly affect results; + # it prevents overflow in BF16/FP16 when training with bsz>1 we clamp max values. + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + print("==== probs shape: ", probs.shape) + scores = probs[..., :-1] # we drop the sink here + print("==== scores shape: ", scores.shape) + attn_weights = nn.functional.dropout(scores, p=dropout, training=True) + attn_output = torch.matmul(attn_weights, value_states) + print("==== attn_output shape: ", attn_output.shape) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights diff --git a/benchmarks/test.py b/benchmarks/test.py new file mode 100644 index 00000000000..7352b949b8c --- /dev/null +++ b/benchmarks/test.py @@ -0,0 +1,232 @@ +import torch +import torch.nn.functional as F +from flash_attn_with_sink import flash_attn_with_sink_func +from naive_attn_with_sink import eager_attention_forward + + +if __name__ == "__main__": + batch = 1 + num_attention_heads = 64 + num_key_value_heads = 8 + num_key_value_groups = num_attention_heads // num_key_value_heads + head_dim = 64 + seq_len = 512 + scaling = head_dim**-0.5 + torch.manual_seed(0) + + torch.cuda.set_device(0) + query = torch.randn( + (batch, num_attention_heads, seq_len, head_dim), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + key = torch.randn( + (batch, num_key_value_heads, seq_len, head_dim), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + value = torch.randn( + (batch, num_key_value_heads, seq_len, head_dim), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + sink = torch.randn( + (num_attention_heads,), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + + # Create causal attention mask + # The mask should be of shape (batch, num_heads, seq_len, seq_len) + # For causal attention, we mask out future positions + # (set them to large negative value) + attention_mask = torch.triu( + torch.full( + (seq_len, seq_len), float("-inf"), device="cuda", dtype=torch.bfloat16 + ), + diagonal=1, + ) + attention_mask = ( + attention_mask.unsqueeze(0) + .unsqueeze(0) + .expand(batch, num_attention_heads, -1, -1) + ) + + print("Running eager attention forward...") + eager_output, eager_weights = eager_attention_forward( + query, + key, + value, + sink, + attention_mask=attention_mask, + scaling=scaling, + dropout=0.0, + num_key_value_groups=num_key_value_groups, + ) + + print("Running flash attention forward...") + # Reshape tensors for flash attention (batch, seq_len, num_heads, head_dim) + q_flash = query.transpose(1, 2) # (batch, seq_len, num_heads, head_dim) + k_flash = key.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) + v_flash = value.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) + + flash_output = flash_attn_with_sink_func( + q_flash, + k_flash, + v_flash, + sink, + softmax_scale=scaling, + dropout_p=0.0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + ) + + # Compare outputs + print(f"Eager output shape: {eager_output.shape}, dtype: {eager_output.dtype}") + print(f"Flash output shape: {flash_output.shape}, dtype: {flash_output.dtype}") + + print( + f"Max absolute difference: {torch.max(torch.abs(eager_output - flash_output))}" + ) + print( + f"Mean absolute difference: {torch.mean(torch.abs(eager_output - flash_output))}" + ) + print( + f"Relative error: {torch.mean(torch.abs(eager_output - flash_output) / (torch.abs(eager_output) + 1e-8))}" + ) + + print("\nEager output sample (first 5x5 elements):") + print(eager_output[0, 0, :5, :5]) + print("\nFlash output sample (first 5x5 elements):") + print(flash_output[0, 0, :5, :5]) + + # Test backward pass + print("\n" + "=" * 50) + print("Testing backward pass...") + + # Reset gradients (handle None case) + if query.grad is not None: + query.grad.zero_() + if key.grad is not None: + key.grad.zero_() + if value.grad is not None: + value.grad.zero_() + if sink.grad is not None: + sink.grad.zero_() + + # Compute loss for eager attention + target = torch.randn_like(eager_output, device="cuda") + eager_loss = F.mse_loss(eager_output, target) * 1000 + + print(f"Eager loss: {eager_loss.item():.6f}") + + # Backward pass for eager attention + eager_loss.backward() + + # Save eager gradients + eager_query_grad = query.grad.clone() + eager_key_grad = key.grad.clone() + eager_value_grad = value.grad.clone() + eager_sink_grad = sink.grad.clone() + + print("\nEager gradient information:") + print(f"Query gradient norm: {eager_query_grad.norm().item():.6f}") + print(f"Key gradient norm: {eager_key_grad.norm().item():.6f}") + print(f"Value gradient norm: {eager_value_grad.norm().item():.6f}") + print(f"Sink gradient norm: {eager_sink_grad.norm().item():.6f}") + + # Reset gradients for flash attention (handle None case) + if query.grad is not None: + query.grad.zero_() + if key.grad is not None: + key.grad.zero_() + if value.grad is not None: + value.grad.zero_() + if sink.grad is not None: + sink.grad.zero_() + + # Compute loss for flash attention + flash_loss = F.mse_loss(flash_output, target) * 1000 + + print(f"\nFlash loss: {flash_loss.item():.6f}") + + # Backward pass for flash attention + flash_loss.backward() + + # Save flash gradients + flash_query_grad = query.grad.clone() + flash_key_grad = key.grad.clone() + flash_value_grad = value.grad.clone() + flash_sink_grad = sink.grad.clone() + + print("\nFlash gradient information:") + print(f"Query gradient norm: {flash_query_grad.norm().item():.6f}") + print(f"Key gradient norm: {flash_key_grad.norm().item():.6f}") + print(f"Value gradient norm: {flash_value_grad.norm().item():.6f}") + print(f"Sink gradient norm: {flash_sink_grad.norm().item():.6f}") + + # Compare gradients + print("\n" + "=" * 50) + print("Comparing gradients...") + + # Calculate gradient differences + query_grad_diff = torch.abs(eager_query_grad - flash_query_grad).max().item() + key_grad_diff = torch.abs(eager_key_grad - flash_key_grad).max().item() + value_grad_diff = torch.abs(eager_value_grad - flash_value_grad).max().item() + sink_grad_diff = torch.abs(eager_sink_grad - flash_sink_grad).max().item() + + print(f"Query gradient max difference: {query_grad_diff:.2e}") + print(f"Key gradient max difference: {key_grad_diff:.2e}") + print(f"Value gradient max difference: {value_grad_diff:.2e}") + print(f"Sink gradient max difference: {sink_grad_diff:.2e}") + + # Check if gradients are close (within tolerance) + tolerance = 1e-3 # Adjust tolerance as needed + query_grad_close = query_grad_diff < tolerance + key_grad_close = key_grad_diff < tolerance + value_grad_close = value_grad_diff < tolerance + sink_grad_close = sink_grad_diff < tolerance + + print(f"\nGradient comparison (tolerance: {tolerance}):") + print(f"Query gradients close: {'✅' if query_grad_close else '❌'}") + print(f"Key gradients close: {'✅' if key_grad_close else '❌'}") + print(f"Value gradients close: {'✅' if value_grad_close else '❌'}") + print(f"Sink gradients close: {'✅' if sink_grad_close else '❌'}") + + # Check if gradients are non-zero + query_grad_zero = eager_query_grad.norm().item() < 1e-8 + key_grad_zero = eager_key_grad.norm().item() < 1e-8 + value_grad_zero = eager_value_grad.norm().item() < 1e-8 + sink_grad_zero = eager_sink_grad.norm().item() < 1e-8 + + print(f"\nGradient non-zero check:") + print(f"Query gradient non-zero: {'✅' if not query_grad_zero else '❌'}") + print(f"Key gradient non-zero: {'✅' if not key_grad_zero else '❌'}") + print(f"Value gradient non-zero: {'✅' if not value_grad_zero else '❌'}") + print(f"Sink gradient non-zero: {'✅' if not sink_grad_zero else '❌'}") + + all_grads_close = ( + query_grad_close and key_grad_close and value_grad_close and sink_grad_close + ) + all_grads_nonzero = not ( + query_grad_zero or key_grad_zero or value_grad_zero or sink_grad_zero + ) + + print(f"\nOverall result:") + print(f" All gradients close: {'✅' if all_grads_close else '❌'}") + print(f" All gradients non-zero: {'✅' if all_grads_nonzero else '❌'}") + + if all_grads_close and all_grads_nonzero: + print("\n🎉 Backward test passed! Gradients match and are non-zero.") + else: + print("\n❌ Backward test failed!") + if not all_grads_close: + print(" - Some gradients don't match between eager and flash attention") + if not all_grads_nonzero: + print(" - Some gradients are zero") diff --git a/benchmarks/test_fused.py b/benchmarks/test_fused.py new file mode 100644 index 00000000000..ee348626b23 --- /dev/null +++ b/benchmarks/test_fused.py @@ -0,0 +1,237 @@ +import torch +import torch.nn.functional as F +from flash_attn_with_sink_fused import flash_attn_with_sink_fused_func +from naive_attn_with_sink import eager_attention_forward + + +if __name__ == "__main__": + batch = 1 + num_attention_heads = 64 + num_key_value_heads = 8 + num_key_value_groups = num_attention_heads // num_key_value_heads + head_dim = 64 + seq_len = 8 + scaling = head_dim**-0.5 + torch.manual_seed(0) + + torch.cuda.set_device(0) + query = torch.randn( + (batch, num_attention_heads, seq_len, head_dim), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + key = torch.randn( + (batch, num_key_value_heads, seq_len, head_dim), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + value = torch.randn( + (batch, num_key_value_heads, seq_len, head_dim), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + sink = torch.randn( + (num_attention_heads,), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + print("sink = ", sink) + + # Create causal attention mask + # The mask should be of shape (batch, num_heads, seq_len, seq_len) + # For causal attention, we mask out future positions + # (set them to large negative value) + attention_mask = torch.triu( + torch.full( + (seq_len, seq_len), float("-inf"), device="cuda", dtype=torch.bfloat16 + ), + diagonal=1, + ) + attention_mask = ( + attention_mask.unsqueeze(0) + .unsqueeze(0) + .expand(batch, num_attention_heads, -1, -1) + ) + + print("Running eager attention forward...") + eager_output, eager_weights = eager_attention_forward( + query, + key, + value, + sink, + attention_mask=attention_mask, + scaling=scaling, + dropout=0.0, + num_key_value_groups=num_key_value_groups, + ) + + print("Running flash attention forward...") + # Reshape tensors for flash attention (batch, seq_len, num_heads, head_dim) + q_flash = query.transpose(1, 2) # (batch, seq_len, num_heads, head_dim) + k_flash = key.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) + v_flash = value.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) + + flash_output = flash_attn_with_sink_fused_func( + q_flash, + k_flash, + v_flash, + sink, + softmax_scale=scaling, + dropout_p=0.0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + ) + + # Compare outputs + print(f"Eager output shape: {eager_output.shape}, dtype: {eager_output.dtype}") + print(f"Flash output shape: {flash_output.shape}, dtype: {flash_output.dtype}") + + print( + f"Max absolute difference: {torch.max(torch.abs(eager_output - flash_output))}" + ) + print( + f"Mean absolute difference: {torch.mean(torch.abs(eager_output - flash_output))}" + ) + print( + f"Relative error: {torch.mean(torch.abs(eager_output - flash_output) / (torch.abs(eager_output) + 1e-8))}" + ) + + print("\nEager output sample (first 8x8 elements):") + print(eager_output[0, 0, :8, :8]) + print("\nFlash output sample (first 8x8 elements):") + print(flash_output[0, 0, :8, :8]) + print(eager_output[0, 0, :8, :8]/flash_output[0, 0, :8, :8]) + + # Test backward pass + print("\n" + "=" * 50) + print("Testing backward pass...") + + # Reset gradients (handle None case) + if query.grad is not None: + query.grad.zero_() + if key.grad is not None: + key.grad.zero_() + if value.grad is not None: + value.grad.zero_() + if sink.grad is not None: + sink.grad.zero_() + + # Compute loss for eager attention + target = torch.randn_like(eager_output, device="cuda") + eager_loss = F.mse_loss(eager_output, target) * 1000 + + print(f"Eager loss: {eager_loss.item():.6f}") + + # Backward pass for eager attention + eager_loss.backward() + + # Save eager gradients + eager_query_grad = query.grad.clone() + eager_key_grad = key.grad.clone() + eager_value_grad = value.grad.clone() + eager_sink_grad = sink.grad.clone() + + print("\nEager gradient information:") + print(f"Query gradient norm: {eager_query_grad.norm().item():.6f}") + print(f"Key gradient norm: {eager_key_grad.norm().item():.6f}") + print(f"Value gradient norm: {eager_value_grad.norm().item():.6f}") + print(f"Sink gradient norm: {eager_sink_grad.norm().item():.6f}") + + # Reset gradients for flash attention (handle None case) + if query.grad is not None: + query.grad.zero_() + if key.grad is not None: + key.grad.zero_() + if value.grad is not None: + value.grad.zero_() + if sink.grad is not None: + sink.grad.zero_() + + # Compute loss for flash attention + flash_loss = F.mse_loss(flash_output, target) * 1000 + + print(f"\nFlash loss: {flash_loss.item():.6f}") + + # Backward pass for flash attention + flash_loss.backward() + + # Save flash gradients + flash_query_grad = query.grad.clone() + flash_key_grad = key.grad.clone() + flash_value_grad = value.grad.clone() + flash_sink_grad = sink.grad.clone() + + print("\nFlash gradient information:") + print(f"Query gradient norm: {flash_query_grad.norm().item():.6f}") + print(f"Key gradient norm: {flash_key_grad.norm().item():.6f}") + print(f"Value gradient norm: {flash_value_grad.norm().item():.6f}") + print(f"Sink gradient norm: {flash_sink_grad.norm().item():.6f}") + + # Compare gradients + print("\n" + "=" * 50) + print("Comparing gradients...") + + # Calculate gradient differences + query_grad_diff = torch.abs(eager_query_grad - flash_query_grad).max().item() + key_grad_diff = torch.abs(eager_key_grad - flash_key_grad).max().item() + value_grad_diff = torch.abs(eager_value_grad - flash_value_grad).max().item() + sink_grad_diff = torch.abs(eager_sink_grad - flash_sink_grad).max().item() + + print(f"Query gradient max difference: {query_grad_diff:.2e}") + print(f"Key gradient max difference: {key_grad_diff:.2e}") + print(f"Value gradient max difference: {value_grad_diff:.2e}") + print(f"Sink gradient max difference: {sink_grad_diff:.2e}") + + # Check if gradients are close (within tolerance) + tolerance = 1e-3 # Adjust tolerance as needed + query_grad_close = query_grad_diff < tolerance + key_grad_close = key_grad_diff < tolerance + value_grad_close = value_grad_diff < tolerance + sink_grad_close = sink_grad_diff < tolerance + + print(f"\nGradient comparison (tolerance: {tolerance}):") + print(f"Query gradients close: {'✅' if query_grad_close else '❌'}") + print(f"Key gradients close: {'✅' if key_grad_close else '❌'}") + print(f"Value gradients close: {'✅' if value_grad_close else '❌'}") + print(f"Sink gradients close: {'✅' if sink_grad_close else '❌'}") + + # Check if gradients are non-zero + query_grad_zero = eager_query_grad.norm().item() < 1e-8 + key_grad_zero = eager_key_grad.norm().item() < 1e-8 + value_grad_zero = eager_value_grad.norm().item() < 1e-8 + sink_grad_zero = eager_sink_grad.norm().item() < 1e-8 + + print(f"\nGradient non-zero check:") + print(f"Query gradient non-zero: {'✅' if not query_grad_zero else '❌'}") + print(f"Key gradient non-zero: {'✅' if not key_grad_zero else '❌'}") + print(f"Value gradient non-zero: {'✅' if not value_grad_zero else '❌'}") + print(f"Sink gradient non-zero: {'✅' if not sink_grad_zero else '❌'}") + + all_grads_close = ( + query_grad_close and key_grad_close and value_grad_close and sink_grad_close + ) + all_grads_nonzero = not ( + query_grad_zero or key_grad_zero or value_grad_zero or sink_grad_zero + ) + + print(f"\nOverall result:") + print(f" All gradients close: {'✅' if all_grads_close else '❌'}") + print(f" All gradients non-zero: {'✅' if all_grads_nonzero else '❌'}") + + if all_grads_close and all_grads_nonzero: + print("\n🎉 Backward test passed! Gradients match and are non-zero.") + else: + print("\n❌ Backward test failed!") + if not all_grads_close: + print(" - Some gradients don't match between eager and flash attention") + if not all_grads_nonzero: + print(" - Some gradients are zero") + + print("sink = ", sink.dtype, sink) + From f3b159eaf453f73c80b3da81884fa4cb2ba500b0 Mon Sep 17 00:00:00 2001 From: jerryao Date: Sun, 17 Aug 2025 03:51:36 +0800 Subject: [PATCH 03/29] Right fwd. --- csrc/flash_attn/src/flash_fwd_kernel.h | 13 ++++++++++--- csrc/flash_attn/src/softmax.h | 19 +++++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index d492c87b5c8..00404f64165 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -51,12 +51,15 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + constexpr bool Has_sink = true; + using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; // Shared memory. extern __shared__ char smem_[]; + __shared__ float shared_sink_val; // The thread index. const int tidx = threadIdx.x; @@ -128,6 +131,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + if (Has_sink) { + shared_sink_val = (params.sink_ptr != nullptr) ? static_cast(reinterpret_cast(params.sink_ptr)[bidh]) : -INFINITY; + } + // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). @@ -282,7 +289,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); - FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(shared_sink_val); const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); @@ -340,7 +347,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // Convert acc_s from fp32 to fp16/bf16 @@ -430,7 +437,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 01589adedb3..f12cc996c29 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -130,16 +130,23 @@ struct Softmax { using TensorT = decltype(make_tensor(Shape>{})); TensorT row_max, row_sum; + float sink_val; - __forceinline__ __device__ Softmax() {}; + __forceinline__ __device__ Softmax(const float sink_val = -INFINITY) : sink_val(sink_val) {}; - template + template __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); if (Is_first) { - FLASH_NAMESPACE::template reduce_max(scores, row_max); + if (Has_sink) { + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { row_max(mi) = sink_val; } + FLASH_NAMESPACE::template reduce_max(scores, row_max); + } else { + FLASH_NAMESPACE::template reduce_max(scores, row_max); + } FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); FLASH_NAMESPACE::reduce_sum(scores, row_sum); } else { @@ -166,7 +173,7 @@ struct Softmax { } }; - template + template __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); @@ -176,6 +183,10 @@ struct Softmax { #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { float sum = row_sum(mi); + if (Has_sink) { + const float max_scaled = row_max(mi) == -INFINITY ? 0.f : row_max(mi) * softmax_scale; + sum += expf(sink_val - max_scaled); + } float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; From be46c294c38d2222eaf4aced59e14c7a64748467 Mon Sep 17 00:00:00 2001 From: jerryao Date: Sun, 17 Aug 2025 23:24:43 +0800 Subject: [PATCH 04/29] Right bwd. --- benchmarks/benchmark_flash_attention.py | 20 +++- benchmarks/flash_attn_with_sink_fused.py | 112 +++-------------------- benchmarks/test_fused.py | 39 ++++++-- csrc/flash_attn/flash_api.cpp | 4 +- csrc/flash_attn/src/flash_bwd_kernel.h | 31 +++++++ csrc/flash_attn/src/flash_fwd_kernel.h | 2 +- csrc/flash_attn/src/softmax.h | 6 ++ 7 files changed, 103 insertions(+), 111 deletions(-) diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index f3030bc86b6..9cada0156ef 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -14,6 +14,7 @@ from flash_attn import flash_attn_qkvpacked_func from flash_attn import flash_attn_func from flash_attn_with_sink import flash_attn_with_sink_func +from flash_attn_with_sink_fused import flash_attn_with_sink_fused_func try: from triton.ops.flash_attention import attention as attention_triton @@ -79,7 +80,7 @@ def time_fwd_bwd(func, *args, **kwargs): dim = 2048 dropout_p = 0.0 -methods = (["Flash2", "Flash2UnPacked", "Pytorch", "Flash2Sink"] +methods = (["Flash2", "Flash2UnPacked", "Pytorch", "Flash2Sink", "Flash2SinkFused"] + (["Triton"] if attention_triton is not None else []) + (["xformers.c"] if xops is not None else []) + (["xformers.f"] if xops is not None else [])) @@ -121,6 +122,23 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "Pytorch"] = f time_b[config, "Pytorch"] = b + try: + scaling = nheads**-0.5 + num_key_value_heads = nheads # // 8 + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + k, v = [torch.randn(batch_size, seqlen, num_key_value_heads, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in range(2)] + sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True) + + f, b = time_fwd_bwd( + flash_attn_with_sink_fused_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False + ) + except: # Skip if OOM + f, b = float('nan'), float('nan') + time_f[config, "Flash2SinkFused"] = f + time_b[config, "Flash2SinkFused"] = b + try: scaling = nheads**-0.5 num_key_value_heads = nheads # // 8 diff --git a/benchmarks/flash_attn_with_sink_fused.py b/benchmarks/flash_attn_with_sink_fused.py index a554a7cf371..e19194ffd7d 100644 --- a/benchmarks/flash_attn_with_sink_fused.py +++ b/benchmarks/flash_attn_with_sink_fused.py @@ -38,8 +38,6 @@ def forward( ctx.return_attn_probs = return_attn_probs ctx.sink_shape = sink.shape # Save original sink shape - # import pdb; pdb.set_trace() - out, lse, _ = flash_attn_sink_func( q, k, @@ -54,51 +52,34 @@ def forward( deterministic=deterministic, return_attn_probs=True, ) - print("==== lse shape: ", lse.shape) - - origin_dtype = out.dtype - - ctx.raw_output = out.clone() - ctx.lse = lse.clone() - - lse = lse.transpose(-2, -1).unsqueeze(dim=-1) - # sink = sink.reshape(1, 1, -1, 1) - - # multiplier = 1 / (torch.exp(sink - lse) + 1) - # out = (out * multiplier).to(origin_dtype) + ctx.lse = lse + ctx.output = out return out + @staticmethod def backward(ctx, grad_output): q, k, v, sink, alibi_slopes = ctx.saved_tensors - raw_output = ctx.raw_output lse = ctx.lse - lse = lse.transpose(-2, -1).unsqueeze(dim=-1) - sink_reshaped = sink.reshape(1, 1, -1, 1) - multiplier = 1 / (torch.exp(sink_reshaped - lse) + 1) - - # 1) Main path via multiplier - grad_raw_output = (grad_output * multiplier).to(q.dtype) - - # Use flash attention backward function for main path - grad_q_main = torch.empty_like(q) - grad_k_main = torch.empty_like(k) + grad_q = torch.empty_like(q) + grad_k = torch.empty_like(k) grad_v = torch.empty_like(v) + grad_sink = torch.empty_like(sink) _flash_attn_sink_backward( - grad_raw_output, # dout: main path gradient + grad_output, # dout: main path gradient q, # q k, # k v, # v sink, - ctx.raw_output, # out: original output + ctx.output, # out: original output lse, # softmax_lse - grad_q_main, # dq: main path - grad_k_main, # dk: main path + grad_q, # dq: main path + grad_k, # dk: main path grad_v, # dv: main path - torch.empty_like(sink), # ds: sink gradient + grad_sink, # ds: sink gradient ctx.dropout_p, # dropout_p ctx.softmax_scale, # softmax_scale ctx.causal, # causal @@ -109,77 +90,6 @@ def backward(ctx, grad_output): ctx.deterministic, # deterministic ) - # 2) Sink gradient path - # g_r = (grad_output * raw_output).sum(dim=-1) # [B,H,Nq] - g_r = torch.sum(grad_output * raw_output, dim=-1) - - # g_ell = g_r * multiplier * (1 - multiplier) # [B,H,Nq] - # Based on debug output: - # g_r shape: [1, 512, 64] (batch, seq_len, heads) - # multiplier shape: [1, 512, 64, 1] (batch, seq_len, heads, 1) - # We need multiplier_for_grad to have shape [1, 512, 64] - # [1, 512, 64, 1] -> [1, 512, 64] - multiplier_for_grad = multiplier.squeeze(-1) - - g_ell = g_r * multiplier_for_grad * (1 - multiplier_for_grad) - # Based on shapes: g_ell [1, 512, 64], we need to sum over seq_len (dim=1) - # to get [1, 64], then sum over batch (dim=0) to get [64] - grad_sink = -torch.sum(g_ell, dim=1) # Sum over seq_len -> [1, 64] - # Sum over batch dimension and reshape to match original sink shape - grad_sink = grad_sink.sum(dim=0) # Sum over batch -> [64] - grad_sink = grad_sink.reshape(ctx.sink_shape) - - # 3) Additional Q gradient via sink - # dQ_extra = scale * g_ell * attention(Q,K,K) - scale = ctx.softmax_scale or (1.0 / q.shape[-1] ** 0.5) - - # Compute attention(Q,K,K) for additional Q gradient - mu_k = flash_attn_sink_func( - q, k, k, sink, - dropout_p=ctx.dropout_p, - softmax_scale=ctx.softmax_scale, - causal=ctx.causal, - window_size=ctx.window_size, - softcap=ctx.softcap, - alibi_slopes=alibi_slopes, - deterministic=ctx.deterministic, - return_attn_probs=False, - ) - grad_q_extra = scale * g_ell.unsqueeze(-1) * mu_k - - # 4) Additional K gradient via sink - # dK_extra = scale * P^T (g_ell * Q) - x = (g_ell.unsqueeze(-1) * q).to(q.dtype) - - # Use flash attention backward to compute P^T X - grad_k_extra = torch.empty_like(k) - _flash_attn_sink_backward( - x, # dout: g_ell * Q - q, # q - k, # k - k, # v (dummy, using K as V) - sink, - ctx.raw_output, # out: original output - lse, # softmax_lse - None, # dq: not needed - None, # dk: not needed - grad_k_extra, # dv: this will be dK_extra - torch.empty_like(sink), - ctx.dropout_p, # dropout_p - ctx.softmax_scale, # softmax_scale - ctx.causal, # causal - ctx.window_size[0], # window_size_left - ctx.window_size[1], # window_size_right - ctx.softcap, # softcap - alibi_slopes, # alibi_slopes - ctx.deterministic, # deterministic - ) - grad_k_extra = scale * grad_k_extra - - # 5) Sum all gradients - grad_q = grad_q_main + grad_q_extra - grad_k = grad_k_main + grad_k_extra - # grad_v already from main path return (grad_q, grad_k, grad_v, grad_sink, None, None, None, None, None, None, None, None) diff --git a/benchmarks/test_fused.py b/benchmarks/test_fused.py index ee348626b23..392981d2978 100644 --- a/benchmarks/test_fused.py +++ b/benchmarks/test_fused.py @@ -10,7 +10,7 @@ num_key_value_heads = 8 num_key_value_groups = num_attention_heads // num_key_value_heads head_dim = 64 - seq_len = 8 + seq_len = 512 scaling = head_dim**-0.5 torch.manual_seed(0) @@ -27,6 +27,13 @@ device="cuda", requires_grad=True, ) + # with torch.no_grad(): + # for h in range(len(key[0])): + # for s in range(len(key[0][h])): + # for d in range(len(key[0][h][s])): + # key[0][h][s][d] = s * 0.1 + print("key = ", key) + # exit() value = torch.randn( (batch, num_key_value_heads, seq_len, head_dim), dtype=torch.bfloat16, @@ -35,10 +42,18 @@ ) sink = torch.randn( (num_attention_heads,), - dtype=torch.bfloat16, + dtype=torch.float32, device="cuda", requires_grad=True, ) + # sink = torch.full( + # (num_attention_heads,), + # 0.5, + # dtype=torch.bfloat16, + # device="cuda", + # requires_grad=True, + # ) + # sink = torch.linspace(0, 1, num_attention_heads, dtype=torch.bfloat16, device="cuda", requires_grad=True) print("sink = ", sink) # Create causal attention mask @@ -62,7 +77,7 @@ query, key, value, - sink, + sink.to(torch.bfloat16), attention_mask=attention_mask, scaling=scaling, dropout=0.0, @@ -106,7 +121,16 @@ print(eager_output[0, 0, :8, :8]) print("\nFlash output sample (first 8x8 elements):") print(flash_output[0, 0, :8, :8]) - print(eager_output[0, 0, :8, :8]/flash_output[0, 0, :8, :8]) + print("eager_output / flash_output:\n", eager_output[0, 0, :8, :8] / flash_output[0, 0, :8, :8]) + + # print("query[0, 0] = ", query[0, 0].shape, query[0, 0]) + # print("key[0, 0] = ", key[0, 0].shape, key[0, 0]) + q_tile = q_flash[0, :, 0, :] + k_tile = k_flash[0, :, 0, :] + # print("query * key = ", torch.matmul(q_tile, k_tile.transpose(-2, -1))) + # print("query * key = ", torch.matmul(q_tile, k_tile.transpose(-2, -1))[0]) + # print("query1 * key1 = ", torch.matmul(q_flash[0, :, 1, :], k_flash[0, :, 1, :].transpose(-2, -1))) + # exit() # Test backward pass print("\n" + "=" * 50) @@ -181,7 +205,10 @@ query_grad_diff = torch.abs(eager_query_grad - flash_query_grad).max().item() key_grad_diff = torch.abs(eager_key_grad - flash_key_grad).max().item() value_grad_diff = torch.abs(eager_value_grad - flash_value_grad).max().item() - sink_grad_diff = torch.abs(eager_sink_grad - flash_sink_grad).max().item() + sink_grad_diff = torch.abs(eager_sink_grad.to(flash_sink_grad.dtype) - flash_sink_grad).max().item() + + print("eager_sink_grad = ", eager_sink_grad) + print("flash_sink_grad = ", flash_sink_grad) print(f"Query gradient max difference: {query_grad_diff:.2e}") print(f"Key gradient max difference: {key_grad_diff:.2e}") @@ -189,7 +216,7 @@ print(f"Sink gradient max difference: {sink_grad_diff:.2e}") # Check if gradients are close (within tolerance) - tolerance = 1e-3 # Adjust tolerance as needed + tolerance = 1e-2 # Adjust tolerance as needed query_grad_close = query_grad_diff < tolerance key_grad_close = key_grad_diff < tolerance value_grad_close = value_grad_diff < tolerance diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 82336f9768a..97de9c4fee5 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1254,12 +1254,12 @@ mha_sink_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x mu } if (dsink_.has_value()) { dsink = dsink_.value(); - TORCH_CHECK(dsink.dtype() == sink.dtype(), "dsink must have the same dtype as sink"); + TORCH_CHECK(dsink.dtype() == at::kFloat, "dsink must have the dtype float32"); CHECK_DEVICE(dsink); TORCH_CHECK(dsink.stride(-1) == 1, "dsink must have contiguous last dimension"); CHECK_SHAPE(dsink, num_heads); } else { - dsink = torch::empty_like(sink); + dsink = torch::empty_like(sink, sink.options().dtype(at::kFloat)); } // bool loop = seqlen_k > blocksize_c; diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 50af5f63073..0242ff8c97a 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -79,6 +79,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { + constexpr bool Has_sink = true; using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -86,6 +87,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Shared memory. extern __shared__ char smem_[]; + __shared__ float shared_sink_val; + float dsink_val = 0.f; // The thread index. const int tidx = threadIdx.x; @@ -105,6 +108,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); } + if (Has_sink) { + shared_sink_val = (params.sink_ptr != nullptr) ? static_cast(reinterpret_cast(params.sink_ptr)[bidh]) : -INFINITY; + } + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) @@ -577,6 +584,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV ); + // Has dP acc_dp + // Has P scores + // Has lse + // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); auto pointwise_mult = [](float p, float dp, float d) { @@ -584,12 +595,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in }; #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { + float dsink_val_temp = 0.f; #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { + dsink_val_temp += dS(mi, ni) * scores(mi, ni); float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } dS(mi, ni) = scaled_ds; } + dsink_val += dsink_val_temp / expf(lse(mi)); } // if (cute::thread0()) { print(dS); } @@ -790,6 +804,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); + if (Has_sink) { + dsink_val = warpReduceSum(dsink_val); + float* dsink_ptr = reinterpret_cast(params.dsink_ptr); + + if (tidx % 32 == 0) { + float val = -dsink_val * expf(shared_sink_val); + atomicAdd(dsink_ptr + bidh, val); + printf("tidx: %d, bidh: %d, dsink_val: %f, shared_sink_val: %f, add_val: %f\n", tidx, bidh, dsink_val, shared_sink_val, val); + } + // if (tidx % 32 == 0) { + // atomicAdd(reinterpret_cast(params.dsink_ptr) + bidh, static_cast(-dsink_val * expf(shared_sink_val))); + // printf("tidx: %d, bidh: %d, dsink_val: %f, shared_sink_val: %f\n", tidx, bidh, dsink_val, shared_sink_val); + // } + // reinterpret_cast(params.dsink_ptr)[bidh] += static_cast(-dsink_val * expf(shared_sink_val)); + // + } + } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 00404f64165..33f6f051b0f 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -132,7 +132,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } if (Has_sink) { - shared_sink_val = (params.sink_ptr != nullptr) ? static_cast(reinterpret_cast(params.sink_ptr)[bidh]) : -INFINITY; + shared_sink_val = (params.sink_ptr != nullptr) ? static_cast(reinterpret_cast(params.sink_ptr)[bidh]) : -INFINITY; } // We iterate over the blocks in reverse order. This is because the last block is the only one diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index f12cc996c29..b4a36eea887 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -123,6 +123,12 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor &ten } } +__inline__ __device__ float warpReduceSum(float val) { + for (int offset = 16; offset > 0; offset /= 2) + val += __shfl_down_sync(0xffffffff, val, offset); + return val; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// template From a53a8d8696f12fe78e4f17e13d7a9d11aa095b50 Mon Sep 17 00:00:00 2001 From: jerryao Date: Mon, 18 Aug 2025 01:56:22 +0800 Subject: [PATCH 05/29] Refine. --- benchmarks/benchmark_flash_attention.py | 4 +- benchmarks/test.py | 1 + benchmarks/test_fused.py | 28 ------------ csrc/flash_attn/src/flash_bwd_kernel.h | 35 ++++++--------- .../src/flash_bwd_launch_template.h | 28 ++++++------ csrc/flash_attn/src/flash_fwd_kernel.h | 12 +++-- .../src/flash_fwd_launch_template.h | 44 ++++++++++--------- csrc/flash_attn/src/softmax.h | 2 +- 8 files changed, 61 insertions(+), 93 deletions(-) diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 9cada0156ef..84e28978183 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -13,8 +13,8 @@ from flash_attn import flash_attn_qkvpacked_func from flash_attn import flash_attn_func +from flash_attn import flash_attn_sink_func from flash_attn_with_sink import flash_attn_with_sink_func -from flash_attn_with_sink_fused import flash_attn_with_sink_fused_func try: from triton.ops.flash_attention import attention as attention_triton @@ -132,7 +132,7 @@ def time_fwd_bwd(func, *args, **kwargs): sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True) f, b = time_fwd_bwd( - flash_attn_with_sink_fused_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False + flash_attn_sink_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False ) except: # Skip if OOM f, b = float('nan'), float('nan') diff --git a/benchmarks/test.py b/benchmarks/test.py index 7352b949b8c..e24098caaee 100644 --- a/benchmarks/test.py +++ b/benchmarks/test.py @@ -105,6 +105,7 @@ print(eager_output[0, 0, :5, :5]) print("\nFlash output sample (first 5x5 elements):") print(flash_output[0, 0, :5, :5]) + print("eager_output / flash_output:\n", eager_output[0, 0, :8, :8] / flash_output[0, 0, :8, :8]) # Test backward pass print("\n" + "=" * 50) diff --git a/benchmarks/test_fused.py b/benchmarks/test_fused.py index 392981d2978..849e6cfff66 100644 --- a/benchmarks/test_fused.py +++ b/benchmarks/test_fused.py @@ -27,13 +27,6 @@ device="cuda", requires_grad=True, ) - # with torch.no_grad(): - # for h in range(len(key[0])): - # for s in range(len(key[0][h])): - # for d in range(len(key[0][h][s])): - # key[0][h][s][d] = s * 0.1 - print("key = ", key) - # exit() value = torch.randn( (batch, num_key_value_heads, seq_len, head_dim), dtype=torch.bfloat16, @@ -46,15 +39,6 @@ device="cuda", requires_grad=True, ) - # sink = torch.full( - # (num_attention_heads,), - # 0.5, - # dtype=torch.bfloat16, - # device="cuda", - # requires_grad=True, - # ) - # sink = torch.linspace(0, 1, num_attention_heads, dtype=torch.bfloat16, device="cuda", requires_grad=True) - print("sink = ", sink) # Create causal attention mask # The mask should be of shape (batch, num_heads, seq_len, seq_len) @@ -123,15 +107,6 @@ print(flash_output[0, 0, :8, :8]) print("eager_output / flash_output:\n", eager_output[0, 0, :8, :8] / flash_output[0, 0, :8, :8]) - # print("query[0, 0] = ", query[0, 0].shape, query[0, 0]) - # print("key[0, 0] = ", key[0, 0].shape, key[0, 0]) - q_tile = q_flash[0, :, 0, :] - k_tile = k_flash[0, :, 0, :] - # print("query * key = ", torch.matmul(q_tile, k_tile.transpose(-2, -1))) - # print("query * key = ", torch.matmul(q_tile, k_tile.transpose(-2, -1))[0]) - # print("query1 * key1 = ", torch.matmul(q_flash[0, :, 1, :], k_flash[0, :, 1, :].transpose(-2, -1))) - # exit() - # Test backward pass print("\n" + "=" * 50) print("Testing backward pass...") @@ -259,6 +234,3 @@ print(" - Some gradients don't match between eager and flash attention") if not all_grads_nonzero: print(" - Some gradients are zero") - - print("sink = ", sink.dtype, sink) - diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 0242ff8c97a..87ea64bd4b1 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -77,9 +77,8 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { - constexpr bool Has_sink = true; using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -108,8 +107,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); } - if (Has_sink) { - shared_sink_val = (params.sink_ptr != nullptr) ? static_cast(reinterpret_cast(params.sink_ptr)[bidh]) : -INFINITY; + if constexpr (Has_sink) { + if (tidx == 0) { shared_sink_val = static_cast(reinterpret_cast(params.sink_ptr)[bidh]); } } const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) @@ -595,15 +594,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in }; #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { - float dsink_val_temp = 0.f; + float dsink_val_cols = 0.f; #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { - dsink_val_temp += dS(mi, ni) * scores(mi, ni); + if constexpr (Has_sink) { dsink_val_cols += dS(mi, ni) * scores(mi, ni); } float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } dS(mi, ni) = scaled_ds; } - dsink_val += dsink_val_temp / expf(lse(mi)); + if constexpr (Has_sink) { dsink_val += dsink_val_cols / expf(lse(mi)); } } // if (cute::thread0()) { print(dS); } @@ -804,21 +803,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); - if (Has_sink) { - dsink_val = warpReduceSum(dsink_val); - float* dsink_ptr = reinterpret_cast(params.dsink_ptr); - + if constexpr (Has_sink) { + SumOp sum_op; + dsink_val = Allreduce<4>::run(dsink_val, sum_op); if (tidx % 32 == 0) { - float val = -dsink_val * expf(shared_sink_val); + float* dsink_ptr = reinterpret_cast(params.dsink_ptr); + float val = -dsink_val * exp2f(shared_sink_val * float(M_LOG2E)); atomicAdd(dsink_ptr + bidh, val); - printf("tidx: %d, bidh: %d, dsink_val: %f, shared_sink_val: %f, add_val: %f\n", tidx, bidh, dsink_val, shared_sink_val, val); } - // if (tidx % 32 == 0) { - // atomicAdd(reinterpret_cast(params.dsink_ptr) + bidh, static_cast(-dsink_val * expf(shared_sink_val))); - // printf("tidx: %d, bidh: %d, dsink_val: %f, shared_sink_val: %f\n", tidx, bidh, dsink_val, shared_sink_val); - // } - // reinterpret_cast(params.dsink_ptr)[bidh] += static_cast(-dsink_val * expf(shared_sink_val)); - // } } @@ -838,6 +830,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int tidx = threadIdx.x; const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + printf("n_block_max = %d\n", n_block_max); if (n_block_max == 1) { compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { @@ -852,7 +845,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. @@ -862,7 +855,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 72e7a333b3a..3b3b1e79640 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -39,10 +39,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -99,17 +99,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.sink_ptr != nullptr, Has_sink, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 33f6f051b0f..bd5d4f30698 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -48,11 +48,9 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { - constexpr bool Has_sink = true; - using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; @@ -131,8 +129,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } - if (Has_sink) { - shared_sink_val = (params.sink_ptr != nullptr) ? static_cast(reinterpret_cast(params.sink_ptr)[bidh]) : -INFINITY; + if constexpr (Has_sink) { + if (tidx == 0) { shared_sink_val = static_cast(reinterpret_cast(params.sink_ptr)[bidh]); } } // We iterate over the blocks in reverse order. This is because the last block is the only one @@ -1079,7 +1077,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1095,7 +1093,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 934e7b9114b..da11589e06c 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -29,10 +29,10 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -71,25 +71,27 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.sink_ptr != nullptr, Has_sink, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index b4a36eea887..c74131eacd6 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -146,7 +146,7 @@ struct Softmax { Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); if (Is_first) { - if (Has_sink) { + if constexpr (Has_sink) { #pragma unroll for (int mi = 0; mi < size(row_max); ++mi) { row_max(mi) = sink_val; } FLASH_NAMESPACE::template reduce_max(scores, row_max); From 5edbfdc016d0f8fd831bcc80cf10b18815790fa0 Mon Sep 17 00:00:00 2001 From: jerryao Date: Mon, 18 Aug 2025 08:04:43 +0800 Subject: [PATCH 06/29] Fix warp reduce. --- benchmarks/flash_attn_with_sink.py | 1 - benchmarks/naive_attn_with_sink.py | 11 ----------- csrc/flash_attn/src/flash_bwd_kernel.h | 6 +----- csrc/flash_attn/src/softmax.h | 6 ------ flash_attn/flash_attn_interface.py | 2 +- 5 files changed, 2 insertions(+), 24 deletions(-) diff --git a/benchmarks/flash_attn_with_sink.py b/benchmarks/flash_attn_with_sink.py index 801f27a789b..ac3af82e788 100644 --- a/benchmarks/flash_attn_with_sink.py +++ b/benchmarks/flash_attn_with_sink.py @@ -53,7 +53,6 @@ def forward( deterministic=deterministic, return_attn_probs=True, ) - print("==== lse shape: ", lse.shape) origin_dtype = out.dtype diff --git a/benchmarks/naive_attn_with_sink.py b/benchmarks/naive_attn_with_sink.py index 79d7be6eacb..d5f8fb1727b 100644 --- a/benchmarks/naive_attn_with_sink.py +++ b/benchmarks/naive_attn_with_sink.py @@ -34,31 +34,20 @@ def eager_attention_forward( ): key_states = repeat_kv(key, num_key_value_groups) value_states = repeat_kv(value, num_key_value_groups) - print("==== sink: ", sink.shape, sink) - print("==== query shape: ", query.shape) - print("==== key_states shape: ", key_states.shape) - print("==== value_states shape: ", value_states.shape) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - print("==== attn_weights shape: ", attn_weights.shape) if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - print("==== causal_mask shape: ", causal_mask.shape) attn_weights = attn_weights + causal_mask sinks = sink.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - print("==== sinks shape: ", sinks.shape) combined_logits = torch.cat([attn_weights, sinks], dim=-1) - print("==== combined_logits shape: ", combined_logits.shape) # This was not in the original implementation and slightly affect results; # it prevents overflow in BF16/FP16 when training with bsz>1 we clamp max values. combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) - print("==== probs shape: ", probs.shape) scores = probs[..., :-1] # we drop the sink here - print("==== scores shape: ", scores.shape) attn_weights = nn.functional.dropout(scores, p=dropout, training=True) attn_output = torch.matmul(attn_weights, value_states) - print("==== attn_output shape: ", attn_output.shape) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 87ea64bd4b1..9f666e5e1f0 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -583,10 +583,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV ); - // Has dP acc_dp - // Has P scores - // Has lse - // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); auto pointwise_mult = [](float p, float dp, float d) { @@ -805,7 +801,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if constexpr (Has_sink) { SumOp sum_op; - dsink_val = Allreduce<4>::run(dsink_val, sum_op); + dsink_val = Allreduce<32>::run(dsink_val, sum_op); if (tidx % 32 == 0) { float* dsink_ptr = reinterpret_cast(params.dsink_ptr); float val = -dsink_val * exp2f(shared_sink_val * float(M_LOG2E)); diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index c74131eacd6..9031c9a63a7 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -123,12 +123,6 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor &ten } } -__inline__ __device__ float warpReduceSum(float val) { - for (int offset = 16; offset > 0; offset /= 2) - val += __shfl_down_sync(0xffffffff, val, offset); - return val; -} - //////////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index ae5850b9b20..8906499dc64 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -133,7 +133,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "12.4.0": +if torch.__version__ >= "2.4.0": _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward else: _wrapped_flash_attn_forward = _flash_attn_forward From b4fc07b7d409823360178f5f17da9751798cdb76 Mon Sep 17 00:00:00 2001 From: jerryao Date: Tue, 19 Aug 2025 17:26:43 +0800 Subject: [PATCH 07/29] Add Tests. --- benchmarks/benchmark_flash_attention.py | 25 +- tests/test_flash_attn.py | 466 +++++++++++++++++++++++- 2 files changed, 467 insertions(+), 24 deletions(-) diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 84e28978183..121bf63f27f 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -15,6 +15,7 @@ from flash_attn import flash_attn_func from flash_attn import flash_attn_sink_func from flash_attn_with_sink import flash_attn_with_sink_func +from flash_attn_with_sink_fused import flash_attn_with_sink_fused_func try: from triton.ops.flash_attention import attention as attention_triton @@ -122,20 +123,18 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "Pytorch"] = f time_b[config, "Pytorch"] = b - try: - scaling = nheads**-0.5 - num_key_value_heads = nheads # // 8 - q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) - k, v = [torch.randn(batch_size, seqlen, num_key_value_heads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(2)] - sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True) - f, b = time_fwd_bwd( - flash_attn_sink_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False - ) - except: # Skip if OOM - f, b = float('nan'), float('nan') + scaling = nheads**-0.5 + num_key_value_heads = nheads # // 8 + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + k, v = [torch.randn(batch_size, seqlen, num_key_value_heads, headdim, device=device, dtype=dtype, + requires_grad=True) for _ in range(2)] + sink = torch.randn((nheads,), dtype=torch.float32, device=device, requires_grad=True) + + f, b = time_fwd_bwd( + flash_attn_with_sink_fused_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False + ) time_f[config, "Flash2SinkFused"] = f time_b[config, "Flash2SinkFused"] = b diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index d5bb6ba8531..fa5792b84f4 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1,4 +1,5 @@ import math +from typing import Optional import pytest import torch @@ -6,6 +7,7 @@ from einops import rearrange, repeat from flash_attn import ( flash_attn_func, + flash_attn_sink_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func, flash_attn_varlen_func, @@ -564,6 +566,79 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def attention_sink_ref( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sink: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + num_key_value_groups: int = 8, + **kwargs, +): + key_states = repeat_kv(key, num_key_value_groups) + value_states = repeat_kv(value, num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + sinks = sink.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + # This was not in the original implementation and slightly affect results; + # it prevents overflow in BF16/FP16 when training with bsz>1 we clamp max values. + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = torch.nn.functional.dropout(scores, p=dropout, training=True) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def get_attention_mask( + seqlen_q, + seqlen_k, + causal, + device, + window_size=(-1, -1), + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, +): + if causal: + window_size = (window_size[0], 0) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + device, + key_leftpad=key_leftpad, + ) + return torch.where(local_mask, torch.tensor(float('-inf')), torch.tensor(0.0)) + return None + + @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False, True]) @@ -1455,27 +1530,27 @@ def test_flash_attn_varlen_output( # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) -@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("swap_sq_sk", [False, True]) # @pytest.mark.parametrize("swap_sq_sk", [True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - (1, 239), - (3, 799), - (127, 512), - (127, 513), - (113, 203), - (128, 217), - (113, 211), - (108, 256), + # (1, 239), + # (3, 799), + # (127, 512), + # (127, 513), + # (113, 203), + # (128, 217), + # (113, 211), + # (108, 256), (256, 512), - (1023, 1024), + # (1023, 1024), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) @@ -2523,3 +2598,372 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) + +@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("swap_sq_sk", [True]) +@pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 512)]) +def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + causal = True + # set seed + torch.random.manual_seed(0) + batch_size = 1 + nheads = 64 + num_key_value_groups = 8 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads // num_key_value_groups, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads // num_key_value_groups, d, device=device, dtype=dtype, requires_grad=True) + q_ref = q.transpose(1, 2) + k_ref = k.transpose(1, 2) + v_ref = v.transpose(1, 2) + sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + out = flash_attn_sink_func(q, k, v, sink, 0.0, causal=causal, window_size=window_size) + + attention_mask = get_attention_mask(seqlen_q, seqlen_k, causal, device, window_size) + if attention_mask is not None: + attention_mask = attention_mask.expand(batch_size, nheads, -1, -1).to(dtype) + + out_ref, _ = attention_sink_ref(q_ref.float(), k_ref.float(), v_ref.float(), sink.float(), attention_mask.float(), d**-0.5, 0.0, num_key_value_groups) + out_ref = out_ref.to(dtype) + out_pt, _ = attention_sink_ref(q_ref, k_ref, v_ref, sink.to(dtype), attention_mask, d**-0.5, 0.0, num_key_value_groups) + + print(f"Output max diff: {(out - out_ref).abs().max().detach().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().detach().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().detach().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().detach().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + ( + dq, + dk, + dv, + dsink, + ) = torch.autograd.grad(out, (q, k, v, sink), g) + ( + dq_ref, + dk_ref, + dv_ref, + dsink_ref, + ) = torch.autograd.grad(out_ref, (q, k, v, sink), g) + ( + dq_pt, + dk_pt, + dv_pt, + dsink_pt, + ) = torch.autograd.grad(out_pt, (q, k, v, sink), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().detach().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().detach().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().detach().item()}") + print(f"dS max diff: {(dsink - dsink_ref).abs().max().detach().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().detach().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().detach().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().detach().item()}") + print(f"dS mean diff: {(dsink - dsink_ref).abs().mean().detach().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().detach().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().detach().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().detach().item()}") + print(f"dS Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().detach().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().detach().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().detach().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().detach().item()}") + print(f"dS Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().detach().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().detach().item() <= 2 * (out_pt - out_ref).abs().max().detach().item() + 1e-5 + + assert (dq - dq_ref).abs().max().detach().item() <= 2 * (dq_pt - dq_ref).abs().max().detach().item() + 1e-5 + assert (dk - dk_ref).abs().max().detach().item() <= 2 * (dk_pt - dk_ref).abs().max().detach().item() + 1e-5 + assert (dv - dv_ref).abs().max().detach().item() <= 2 * (dv_pt - dv_ref).abs().max().detach().item() + 1e-5 + assert (dsink - dsink_ref).abs().max().detach().item() <= 2 * (dsink_pt - dsink_ref).abs().max().detach().item() + 1e-5 + + + +""" +:test_flash_attn_causal + +Running 16 items in this shard +Output max diff: 0.00048828125 +Output mean diff: 1.8358230590820312e-05 +Pytorch max diff: 0.0009765625 +Pytorch mean diff: 3.36766242980957e-05 +==== 256 512 False 64 False torch.float16 +dQ max diff: 0.00048828125 +dK max diff: 0.00048828125 +dV max diff: 0.000244140625 +dQ mean diff: 5.364418029785156e-06 +dK mean diff: 3.3974647521972656e-06 +dV mean diff: 1.1920928955078125e-07 +dQ Pytorch max diff: 0.00146484375 +dK Pytorch max diff: 0.00146484375 +dV Pytorch max diff: 0.0009765625 +dQ Pytorch mean diff: 3.927946090698242e-05 +dK Pytorch mean diff: 2.5272369384765625e-05 +dV Pytorch mean diff: 2.1457672119140625e-05 +.Output max diff: 0.00390625 +Output mean diff: 9.679794311523438e-05 +Pytorch max diff: 0.01171875 +Pytorch mean diff: 0.0002498626708984375 +==== 256 512 False 64 False torch.bfloat16 +dQ max diff: 0.0078125 +dK max diff: 0.0078125 +dV max diff: 0.00390625 +dQ mean diff: 0.0001125335693359375 +dK mean diff: 7.104873657226562e-05 +dV mean diff: 7.343292236328125e-05 +dQ Pytorch max diff: 0.01171875 +dK Pytorch max diff: 0.015625 +dV Pytorch max diff: 0.015625 +dQ Pytorch mean diff: 0.0002956390380859375 +dK Pytorch mean diff: 0.00019073486328125 +dV Pytorch mean diff: 0.0001583099365234375 +.Output max diff: 0.00048828125 +Output mean diff: 2.6047229766845703e-05 +Pytorch max diff: 0.001953125 +Pytorch mean diff: 4.690885543823242e-05 +==== 256 512 False 64 True torch.float16 +dQ max diff: 0.0009765625 +dK max diff: 0.0009765625 +dV max diff: 0.00048828125 +dQ mean diff: 8.881092071533203e-06 +dK mean diff: 5.245208740234375e-06 +dV mean diff: 1.1920928955078125e-07 +dQ Pytorch max diff: 0.001953125 +dK Pytorch max diff: 0.0029296875 +dV Pytorch max diff: 0.00146484375 +dQ Pytorch mean diff: 5.352497100830078e-05 +dK Pytorch mean diff: 3.135204315185547e-05 +dV Pytorch mean diff: 2.7120113372802734e-05 +.Output max diff: 0.00390625 +Output mean diff: 0.0001354217529296875 +Pytorch max diff: 0.0078125 +Pytorch mean diff: 0.00034332275390625 +==== 256 512 False 64 True torch.bfloat16 +dQ max diff: 0.0078125 +dK max diff: 0.0078125 +dV max diff: 0.0078125 +dQ mean diff: 0.00015926361083984375 +dK mean diff: 9.202957153320312e-05 +dV mean diff: 9.632110595703125e-05 +dQ Pytorch max diff: 0.015625 +dK Pytorch max diff: 0.015625 +dV Pytorch max diff: 0.015625 +dQ Pytorch mean diff: 0.000400543212890625 +dK Pytorch mean diff: 0.00023555755615234375 +dV Pytorch mean diff: 0.00019931793212890625 +.Output max diff: 0.00048828125 +Output mean diff: 2.2709369659423828e-05 +Pytorch max diff: 0.0009765625 +Pytorch mean diff: 3.910064697265625e-05 +==== 256 512 False 128 False torch.float16 +dQ max diff: 0.00048828125 +dK max diff: 0.0009765625 +dV max diff: 0.00048828125 +dQ mean diff: 2.1457672119140625e-05 +dK mean diff: 1.6450881958007812e-05 +dV mean diff: 1.436471939086914e-05 +dQ Pytorch max diff: 0.00244140625 +dK Pytorch max diff: 0.00244140625 +dV Pytorch max diff: 0.0009765625 +dQ Pytorch mean diff: 4.565715789794922e-05 +dK Pytorch mean diff: 3.063678741455078e-05 +dV Pytorch mean diff: 2.5272369384765625e-05 +.Output max diff: 0.00390625 +Output mean diff: 9.632110595703125e-05 +Pytorch max diff: 0.009765625 +Pytorch mean diff: 0.000270843505859375 +==== 256 512 False 128 False torch.bfloat16 +dQ max diff: 0.00390625 +dK max diff: 0.0078125 +dV max diff: 0.00390625 +dQ mean diff: 0.000110626220703125 +dK mean diff: 7.200241088867188e-05 +dV mean diff: 7.343292236328125e-05 +dQ Pytorch max diff: 0.01171875 +dK Pytorch max diff: 0.01171875 +dV Pytorch max diff: 0.0078125 +dQ Pytorch mean diff: 0.000331878662109375 +dK Pytorch mean diff: 0.000213623046875 +dV Pytorch mean diff: 0.00017452239990234375 +.Output max diff: 0.00048828125 +Output mean diff: 3.236532211303711e-05 +Pytorch max diff: 0.00146484375 +Pytorch mean diff: 5.4717063903808594e-05 +==== 256 512 False 128 True torch.float16 +dQ max diff: 0.0009765625 +dK max diff: 0.0009765625 +dV max diff: 0.0009765625 +dQ mean diff: 2.956390380859375e-05 +dK mean diff: 2.086162567138672e-05 +dV mean diff: 1.8596649169921875e-05 +dQ Pytorch max diff: 0.001953125 +dK Pytorch max diff: 0.00390625 +dV Pytorch max diff: 0.001220703125 +dQ Pytorch mean diff: 6.276369094848633e-05 +dK Pytorch mean diff: 3.8504600524902344e-05 +dV Pytorch mean diff: 3.212690353393555e-05 +.Output max diff: 0.00390625 +Output mean diff: 0.00013446807861328125 +Pytorch max diff: 0.01171875 +Pytorch mean diff: 0.000377655029296875 +==== 256 512 False 128 True torch.bfloat16 +dQ max diff: 0.0078125 +dK max diff: 0.0078125 +dV max diff: 0.00390625 +dQ mean diff: 0.00015735626220703125 +dK mean diff: 9.298324584960938e-05 +dV mean diff: 9.632110595703125e-05 +dQ Pytorch max diff: 0.015625 +dK Pytorch max diff: 0.015625 +dV Pytorch max diff: 0.01171875 +dQ Pytorch mean diff: 0.000453948974609375 +dK Pytorch mean diff: 0.00026702880859375 +dV Pytorch mean diff: 0.00022125244140625 +.Output max diff: 0.001953125 +Output mean diff: 1.7821788787841797e-05 +Pytorch max diff: 0.0029296875 +Pytorch mean diff: 2.8133392333984375e-05 +==== 512 256 True 64 False torch.float16 +dQ max diff: 0.001953125 +dK max diff: 0.001953125 +dV max diff: 0.0009765625 +dQ mean diff: 7.450580596923828e-06 +dK mean diff: 1.3172626495361328e-05 +dV mean diff: 2.384185791015625e-07 +dQ Pytorch max diff: 0.00390625 +dK Pytorch max diff: 0.001953125 +dV Pytorch max diff: 0.00390625 +dQ Pytorch mean diff: 3.224611282348633e-05 +dK Pytorch mean diff: 5.4717063903808594e-05 +dV Pytorch mean diff: 4.7147274017333984e-05 +.Output max diff: 0.015625 +Output mean diff: 8.821487426757812e-05 +Pytorch max diff: 0.015625 +Pytorch mean diff: 0.000213623046875 +==== 512 256 True 64 False torch.bfloat16 +dQ max diff: 0.015625 +dK max diff: 0.015625 +dV max diff: 0.015625 +dQ mean diff: 0.00010776519775390625 +dK mean diff: 0.00017452239990234375 +dV mean diff: 0.00018310546875 +dQ Pytorch max diff: 0.0234375 +dK Pytorch max diff: 0.015625 +dV Pytorch max diff: 0.03125 +dQ Pytorch mean diff: 0.00024318695068359375 +dK Pytorch mean diff: 0.0004119873046875 +dV Pytorch mean diff: 0.0003509521484375 +.Output max diff: 0.001953125 +Output mean diff: 1.817941665649414e-05 +Pytorch max diff: 0.0029296875 +Pytorch mean diff: 2.8789043426513672e-05 +==== 512 256 True 64 True torch.float16 +dQ max diff: 0.001953125 +dK max diff: 0.001953125 +dV max diff: 0.001953125 +dQ mean diff: 7.62939453125e-06 +dK mean diff: 1.3887882232666016e-05 +dV mean diff: 2.384185791015625e-07 +dQ Pytorch max diff: 0.00390625 +dK Pytorch max diff: 0.0029296875 +dV Pytorch max diff: 0.00390625 +dQ Pytorch mean diff: 3.3020973205566406e-05 +dK Pytorch mean diff: 5.7756900787353516e-05 +dV Pytorch mean diff: 4.9948692321777344e-05 +.Output max diff: 0.015625 +Output mean diff: 9.012222290039062e-05 +Pytorch max diff: 0.015625 +Pytorch mean diff: 0.000217437744140625 +==== 512 256 True 64 True torch.bfloat16 +dQ max diff: 0.015625 +dK max diff: 0.015625 +dV max diff: 0.03125 +dQ mean diff: 0.00011014938354492188 +dK mean diff: 0.00018405914306640625 +dV mean diff: 0.0001926422119140625 +dQ Pytorch max diff: 0.0234375 +dK Pytorch max diff: 0.015625 +dV Pytorch max diff: 0.03125 +dQ Pytorch mean diff: 0.000247955322265625 +dK Pytorch mean diff: 0.0004329681396484375 +dV Pytorch mean diff: 0.0003681182861328125 +.Output max diff: 0.001953125 +Output mean diff: 2.187490463256836e-05 +Pytorch max diff: 0.001953125 +Pytorch mean diff: 3.463029861450195e-05 +==== 512 256 True 128 False torch.float16 +dQ max diff: 0.00390625 +dK max diff: 0.001953125 +dV max diff: 0.00390625 +dQ mean diff: 1.8894672393798828e-05 +dK mean diff: 3.814697265625e-05 +dV mean diff: 3.421306610107422e-05 +dQ Pytorch max diff: 0.00390625 +dK Pytorch max diff: 0.00390625 +dV Pytorch max diff: 0.00390625 +dQ Pytorch mean diff: 3.8504600524902344e-05 +dK Pytorch mean diff: 6.74128532409668e-05 +dV Pytorch mean diff: 5.7637691497802734e-05 +.Output max diff: 0.015625 +Output mean diff: 8.821487426757812e-05 +Pytorch max diff: 0.015625 +Pytorch mean diff: 0.0002384185791015625 +==== 512 256 True 128 False torch.bfloat16 +dQ max diff: 0.03125 +dK max diff: 0.015625 +dV max diff: 0.03125 +dQ mean diff: 0.0001087188720703125 +dK mean diff: 0.0001773834228515625 +dV mean diff: 0.00018596649169921875 +dQ Pytorch max diff: 0.03125 +dK Pytorch max diff: 0.0234375 +dV Pytorch max diff: 0.03125 +dQ Pytorch mean diff: 0.000278472900390625 +dK Pytorch mean diff: 0.00046539306640625 +dV Pytorch mean diff: 0.000392913818359375 +.Output max diff: 0.001953125 +Output mean diff: 2.2351741790771484e-05 +Pytorch max diff: 0.001953125 +Pytorch mean diff: 3.546476364135742e-05 +==== 512 256 True 128 True torch.float16 +dQ max diff: 0.00390625 +dK max diff: 0.001953125 +dV max diff: 0.00390625 +dQ mean diff: 1.9371509552001953e-05 +dK mean diff: 4.029273986816406e-05 +dV mean diff: 3.612041473388672e-05 +dQ Pytorch max diff: 0.00390625 +dK Pytorch max diff: 0.0029296875 +dV Pytorch max diff: 0.00390625 +dQ Pytorch mean diff: 3.933906555175781e-05 +dK Pytorch mean diff: 7.12275505065918e-05 +dV Pytorch mean diff: 6.091594696044922e-05 +.Output max diff: 0.015625 +Output mean diff: 9.012222290039062e-05 +Pytorch max diff: 0.015625 +Pytorch mean diff: 0.00024318695068359375 +==== 512 256 True 128 True torch.bfloat16 +dQ max diff: 0.03125 +dK max diff: 0.015625 +dV max diff: 0.03125 +dQ mean diff: 0.00011110305786132812 +dK mean diff: 0.000186920166015625 +dV mean diff: 0.000194549560546875 +dQ Pytorch max diff: 0.03125 +dK Pytorch max diff: 0.0234375 +dV Pytorch max diff: 0.03125 +dQ Pytorch mean diff: 0.000286102294921875 +dK Pytorch mean diff: 0.000492095947265625 +dV Pytorch mean diff: 0.0004138946533203125 +""" \ No newline at end of file From 66b88a58bcd0caa5aa644bbb661e2a48ca507f1e Mon Sep 17 00:00:00 2001 From: jerryao Date: Wed, 20 Aug 2025 17:22:27 +0800 Subject: [PATCH 08/29] Fix max in fwd. --- csrc/flash_attn/src/flash_fwd_kernel.h | 1 + csrc/flash_attn/src/softmax.h | 5 +++-- tests/test_flash_attn.py | 24 +++++++++++++----------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index bd5d4f30698..9e1a5d0a5d7 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -131,6 +131,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi if constexpr (Has_sink) { if (tidx == 0) { shared_sink_val = static_cast(reinterpret_cast(params.sink_ptr)[bidh]); } + __syncthreads(); } // We iterate over the blocks in reverse order. This is because the last block is the only one diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 9031c9a63a7..b4b0f96d31c 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -141,8 +141,9 @@ struct Softmax { static_assert(decltype(size<0>(scores))::value == kNRows); if (Is_first) { if constexpr (Has_sink) { + const float sink_scaled = (sink_val * float(M_LOG2E) / softmax_scale_log2); #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { row_max(mi) = sink_val; } + for (int mi = 0; mi < size(row_max); ++mi) { row_max(mi) = sink_scaled; } FLASH_NAMESPACE::template reduce_max(scores, row_max); } else { FLASH_NAMESPACE::template reduce_max(scores, row_max); @@ -185,7 +186,7 @@ struct Softmax { float sum = row_sum(mi); if (Has_sink) { const float max_scaled = row_max(mi) == -INFINITY ? 0.f : row_max(mi) * softmax_scale; - sum += expf(sink_val - max_scaled); + sum += exp2f((sink_val - max_scaled) * float(M_LOG2E)); } float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index fa5792b84f4..e9331f7b094 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1541,16 +1541,16 @@ def test_flash_attn_varlen_output( @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - # (1, 239), - # (3, 799), - # (127, 512), - # (127, 513), - # (113, 203), - # (128, 217), - # (113, 211), - # (108, 256), + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), (256, 512), - # (1023, 1024), + (1023, 1024), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) @@ -2599,7 +2599,7 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.bfloat16])) @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("swap_sq_sk", [True]) @@ -2627,7 +2627,7 @@ def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype) k_ref = k.transpose(1, 2) v_ref = v.transpose(1, 2) sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) - out = flash_attn_sink_func(q, k, v, sink, 0.0, causal=causal, window_size=window_size) + out = flash_attn_sink_func(q, k, v, sink, 0.0, softmax_scale=d**-0.5, causal=causal, window_size=window_size) attention_mask = get_attention_mask(seqlen_q, seqlen_k, causal, device, window_size) if attention_mask is not None: @@ -2678,6 +2678,8 @@ def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype) print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().detach().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().detach().item()}") print(f"dS Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().detach().item()}") + print(f"dS Relative error: {torch.mean(torch.abs(dsink - dsink_ref) / (torch.abs(dsink_ref) + 1e-8)).detach()}") + print(f"dS Pytorch relative error: {torch.mean(torch.abs(dsink_pt - dsink_ref) / (torch.abs(dsink_ref) + 1e-8)).detach()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. From 27a005f06469c1eac40cebb61ad783bdf08b0f91 Mon Sep 17 00:00:00 2001 From: jerryao Date: Wed, 20 Aug 2025 21:02:56 +0800 Subject: [PATCH 09/29] Fix init dsink. --- csrc/flash_attn/flash_api.cpp | 3 +++ flash_attn/flash_attn_interface.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 97de9c4fee5..c428aaa3389 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -540,6 +540,7 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round "FlashAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(sink.dtype() == torch::kFloat32, "sink must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); @@ -559,6 +560,7 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads == sink.size(0), "Number of heads in query must match sink size"); if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } @@ -582,6 +584,7 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(sink, num_heads); at::Tensor out; if (out_.has_value()) { diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 8906499dc64..266a23e28bb 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -1039,7 +1039,7 @@ def forward( @staticmethod def backward(ctx, dout, *args): q, k, v, sink, out, softmax_lse, rng_state = ctx.saved_tensors - dq, dk, dv, dsink = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(sink) + dq, dk, dv, dsink = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.zeros_like(sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: From 26b4704cb8ddad0213e58280a1b18df035ec0c18 Mon Sep 17 00:00:00 2001 From: jerryao Date: Thu, 21 Aug 2025 16:04:00 +0800 Subject: [PATCH 10/29] Add to existing interface. --- csrc/flash_attn/flash_api.cpp | 440 ++---------------- csrc/flash_attn/src/flash.h | 2 +- csrc/flash_attn/src/flash_fwd_kernel.h | 29 +- .../src/flash_fwd_launch_template.h | 36 +- flash_attn/flash_attn_interface.py | 330 ++++--------- 5 files changed, 183 insertions(+), 654 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index c428aaa3389..597c66aca8a 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -347,176 +347,38 @@ void set_params_alibi(Flash_fwd_params ¶ms, std::optional &alibi #endif } -std::vector -mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - std::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const float softcap, - const bool return_softmax, - std::optional gen_) { - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - int seqlen_q = sizes[1]; - int num_heads = sizes[2]; - const int head_size = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } - - if (window_size_left >= seqlen_k) { window_size_left = -1; } - if (window_size_right >= seqlen_k) { window_size_right = -1; } - - // causal=true is the same as causal=false in this case - if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } - if (is_causal) { window_size_right = 0; } - - // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case - // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); - const int ngroups = num_heads / num_heads_k; - if (seqlenq_ngroups_swapped) { - q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); - seqlen_q = ngroups; - num_heads = num_heads_k; - } - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); - if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); +void set_params_sink(Flash_fwd_params ¶ms, const std::optional &learnable_sink_, int num_heads, const std::optional &dsink_=std::nullopt) { +#ifdef FLASHATTENTION_DISABLE_ALIBI + TORCH_CHECK(!learnable_sink_.has_value(), "This flash attention build does not support learnable sink."); + params.learnable_sink_ptr = nullptr; +#else + if (learnable_sink_.has_value()) { + auto learnable_sink = learnable_sink_.value(); + TORCH_CHECK(learnable_sink.dtype() == torch::kFloat32, "Learnable sink must have dtype fp32"); + CHECK_DEVICE(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + if (dsink_.has_value()) { + auto dsink = dsink_.value(); + CHECK_DEVICE(dsink); + CHECK_SHAPE(dsink, num_heads); + params.dsink_ptr = dsink.data_ptr(); + } else { + params.dsink_ptr = nullptr; } } else { - out = torch::empty_like(q); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - auto opts = q.options(); - - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor p; - // Only return softmax if there's dropout to reduce compilation time - if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } - else { - p = torch::empty({ 0 }, opts); - } - - Flash_fwd_params params; - set_params_fprop(params, - batch_size, - seqlen_q, seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, out, - /*cu_seqlens_q_d=*/nullptr, - /*cu_seqlens_k_d=*/nullptr, - /*seqused_k=*/nullptr, - return_softmax ? p.data_ptr() : nullptr, - softmax_lse.data_ptr(), - p_dropout, - softmax_scale, - window_size_left, - window_size_right, - softcap - ); - params.sink_ptr = nullptr; - - // Keep references to these tensors to extend their lifetime - at::Tensor softmax_lse_accum, out_accum; - std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( - params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); - - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. - params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - - set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - - if (seqlen_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); - } else { - // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. - out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); - } - - if (seqlenq_ngroups_swapped) { - out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); - q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + params.learnable_sink_ptr = nullptr; + params.dsink_ptr = nullptr; } - return {out, softmax_lse, p, rng_state}; +#endif } std::vector -mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &sink, // num_heads + std::optional &learnable_sink, // num_heads std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, @@ -540,7 +402,6 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round "FlashAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(sink.dtype() == torch::kFloat32, "sink must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); @@ -560,7 +421,6 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - TORCH_CHECK(num_heads == sink.size(0), "Number of heads in query must match sink size"); if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } @@ -584,7 +444,6 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(sink, num_heads); at::Tensor out; if (out_.has_value()) { @@ -637,7 +496,6 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round window_size_right, softcap ); - params.sink_ptr = sink.data_ptr(); // Keep references to these tensors to extend their lifetime @@ -664,6 +522,7 @@ mha_sink_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + set_params_sink(params, learnable_sink, num_heads); if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -686,6 +545,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + std::optional &learnable_sink, // num_heads std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 @@ -904,6 +764,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + set_params_sink(params, learnable_sink, num_heads); if (max_seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -940,11 +801,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &learnable_sink, // num_heads const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dsink_, // num_heads std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, @@ -1017,7 +880,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); - at::Tensor dq, dk, dv; + at::Tensor dq, dk, dv, dsink; if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); @@ -1045,224 +908,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl } else { dv = torch::empty_like(v); } - - // bool loop = seqlen_k > blocksize_c; - // TODO: change later, for now set to true for simplicity - bool loop = true; - - auto opts = q.options(); - auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; - at::Tensor dk_accum, dv_accum; - if (loop) { - if (!deterministic) { - dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + if (learnable_sink.has_value()) { + if (dsink_.has_value()) { + dsink = dsink_.value(); } else { - const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); - dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + dsink = torch::zeros_like(v); } - // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - } - - at::Tensor dk_expanded, dv_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } - - Flash_bwd_params params; - - set_params_dgrad(params, - batch_size, - seqlen_q, seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, out, - dout, dq, dk_expanded, dv_expanded, - nullptr, - nullptr, - loop ? dq_accum.data_ptr() : nullptr, - // loop ? dk_accum.data_ptr() : nullptr, - // loop ? dv_accum.data_ptr() : nullptr, - nullptr, - nullptr, - softmax_lse.data_ptr(), - softmax_d.data_ptr(), - p_dropout, - softmax_scale, - window_size_left, - window_size_right, - softcap, - deterministic, - /*unpadded_lse*/false); - params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); - params.sink_ptr = nullptr; - params.dsink_ptr = nullptr; - - auto launch = &run_mha_bwd; - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - - if ( rng_state.has_value() ) { - params.rng_state = reinterpret_cast(rng_state.value().data_ptr()); - } else if( is_dropout ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - auto seeds = at::cuda::philox::unpack(params.philox_args); - params.rng_state[0] = std::get<0>(seeds); - params.rng_state[1] = std::get<1>(seeds); - } - - set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - - if (seqlen_q > 0) { - launch(params, stream); - } else { - // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. - dk_expanded.zero_(); - dv_expanded.zero_(); - softmax_d.zero_(); - } - - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - } - - return { dq, dk, dv, softmax_d }; -} - -std::vector -mha_sink_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &sink, // num_heads - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dsink_, // num_heads - std::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, // probability to drop - const float softmax_scale, - const bool is_causal, - int window_size_left, - int window_size_right, - const float softcap, - const bool deterministic, - std::optional gen_, - std::optional &rng_state) { - - #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash attention build does not support backward."); - #endif - if (is_causal) { window_size_right = 0; } - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); - - bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - const int seqlen_q = sizes[1]; - const int num_heads = sizes[2]; - const int head_size = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } - - if (window_size_left >= seqlen_k) { window_size_left = -1; } - if (window_size_right >= seqlen_k) { window_size_right = -1; } - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); - - at::Tensor dq, dk, dv, dsink; - if (dq_.has_value()) { - dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); - CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); - } else { - dq = torch::empty_like(q); - } - if (dk_.has_value()) { - dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); - CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dk = torch::empty_like(k); - } - if (dv_.has_value()) { - dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dv = torch::empty_like(v); - } - if (dsink_.has_value()) { - dsink = dsink_.value(); - TORCH_CHECK(dsink.dtype() == at::kFloat, "dsink must have the dtype float32"); - CHECK_DEVICE(dsink); - TORCH_CHECK(dsink.stride(-1) == 1, "dsink must have contiguous last dimension"); - CHECK_SHAPE(dsink, num_heads); - } else { - dsink = torch::empty_like(sink, sink.options().dtype(at::kFloat)); } // bool loop = seqlen_k > blocksize_c; @@ -1320,8 +971,7 @@ mha_sink_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x mu deterministic, /*unpadded_lse*/false); params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); - params.sink_ptr = sink.data_ptr(); - params.dsink_ptr = dsink.data_ptr(); + auto launch = &run_mha_bwd; @@ -1343,6 +993,7 @@ mha_sink_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x mu } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + set_params_sink(params, learnable_sink, num_heads); if (seqlen_q > 0) { launch(params, stream); @@ -1362,16 +1013,19 @@ mha_sink_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x mu return { dq, dk, dv, dsink, softmax_d }; } + std::vector mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &learnable_sink, // num_heads const at::Tensor &out, // total_q x num_heads x head_size const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dsink_, // num_heads const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &alibi_slopes_, // num_heads or b x num_heads @@ -1455,7 +1109,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - at::Tensor dq, dk, dv; + at::Tensor dq, dk, dv, dsink; if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); @@ -1483,6 +1137,13 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } else { dv = torch::empty_like(v); } + if (learnable_sink.has_value()) { + if (dsink_.has_value()) { + dsink = dsink_.value(); + } else { + dsink = torch::zeros_like(v); + } + } // bool loop = max_seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity @@ -1572,6 +1233,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + set_params_sink(params, learnable_sink, num_heads, dsink); if (max_seqlen_q > 0) { launch(params, stream); @@ -1588,7 +1250,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); } - return { dq, dk, dv, softmax_d }; + return { dq, dk, dv, dsink, softmax_d }; } std::vector @@ -1604,6 +1266,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he std::optional &leftpad_k_, // batch_size std::optional &block_table_, // batch_size x max_num_blocks_per_seq std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + std::optional &learnable_sink_, // num_heads std::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, @@ -1842,6 +1505,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + set_params_sink(params, learnable_sink_, num_heads); auto stream = at::cuda::getCurrentCUDAStream().stream(); // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, @@ -1871,9 +1535,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)"); - m.def("sink_fwd", &FLASH_NAMESPACE::mha_sink_fwd, "Forward pass (with sink)"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)"); - m.def("sink_bwd", &FLASH_NAMESPACE::mha_sink_bwd, "Backward pass (with sink)"); m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache"); } diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 89613ddb085..151c4428da2 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -141,7 +141,7 @@ struct Flash_fwd_params : public Qkv_params { bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). - void *__restrict__ sink_ptr; // For gpt_oss + void *__restrict__ learnable_sink_ptr; // For gpt_oss }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 9e1a5d0a5d7..7e3ac17820d 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -48,7 +48,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -57,7 +57,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Shared memory. extern __shared__ char smem_[]; - __shared__ float shared_sink_val; // The thread index. const int tidx = threadIdx.x; @@ -129,11 +128,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } - if constexpr (Has_sink) { - if (tidx == 0) { shared_sink_val = static_cast(reinterpret_cast(params.sink_ptr)[bidh]); } - __syncthreads(); - } - // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). @@ -288,7 +282,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); - FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(shared_sink_val); + const float sink_val = !Has_sink || params.sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.sink_ptr)[bidh]; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); @@ -501,7 +496,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -540,6 +535,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } + if (tidx == 0) printf("compute_attn_1rowblock_splitkv: m_block = %d, binfo.actual_seqlen_q = %d, kBlockM = %d, binfo.actual_seqlen_k = %d, kBlockN = %d, params.sink_ptr = %p, Split = %d\n", m_block, binfo.actual_seqlen_q, kBlockM, binfo.actual_seqlen_k, kBlockN, params.sink_ptr, Split); if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. // Otherwise we might read OOB elements from gK and gV, @@ -839,7 +835,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons clear(acc_o); - FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; + const float sink_val = !Has_sink || params.sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.sink_ptr)[bidh]; + FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); @@ -920,7 +917,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } @@ -1000,7 +997,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1078,7 +1075,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1094,12 +1091,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1108,7 +1105,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index da11589e06c..4b9301291f4 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -29,18 +29,18 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, bool Has_sink) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Has_sink, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Has_sink, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); + FLASH_NAMESPACE::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -77,7 +77,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -116,18 +116,20 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.sink_ptr != nullptr, Has_sink, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 266a23e28bb..5867bbe2f13 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -78,6 +78,7 @@ def _flash_attn_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -92,6 +93,7 @@ def _flash_attn_forward( q, k, v, + learnable_sink, None, alibi_slopes, dropout_p, @@ -111,6 +113,7 @@ def _flash_attn_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -133,7 +136,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "2.4.0": +if torch.__version__ >= "12.4.0": _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward else: _wrapped_flash_attn_forward = _flash_attn_forward @@ -144,6 +147,7 @@ def _flash_attn_varlen_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -166,6 +170,7 @@ def _flash_attn_varlen_forward( q, k, v, + learnable_sink, None, cu_seqlens_q, cu_seqlens_k, @@ -195,6 +200,7 @@ def _flash_attn_varlen_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -240,11 +246,13 @@ def _flash_attn_backward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], + dsink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -261,17 +269,20 @@ def _flash_attn_backward( dq, dk, dv, + dsink, softmax_d, - ) = flash_attn_gpu.bwd( + ) = flash_attn_gpu.sink_bwd( dout, q, k, v, + learnable_sink, out, softmax_lse, dq, dk, dv, + dsink, alibi_slopes, dropout_p, softmax_scale, @@ -292,11 +303,13 @@ def _flash_attn_backward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], + dsink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -314,6 +327,8 @@ def _flash_attn_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) + if dsink is None and learnable_sink is not None: + dsink = torch.empty_like(learnable_sink) batch_size, seqlen_q, num_heads, _ = q.shape softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) @@ -332,11 +347,13 @@ def _flash_attn_varlen_backward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], + dsink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -358,17 +375,20 @@ def _flash_attn_varlen_backward( dq, dk, dv, + dsink, softmax_d, ) = flash_attn_gpu.varlen_bwd( dout, q, k, v, + learnable_sink, out, softmax_lse, dq, dk, dv, + dsink, cu_seqlens_q, cu_seqlens_k, alibi_slopes, @@ -396,11 +416,13 @@ def _flash_attn_varlen_backward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], + dsink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -426,6 +448,8 @@ def _flash_attn_varlen_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) + if dsink is None and learnable_sink is not None: + dsink = torch.empty_like(learnable_sink) softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) return softmax_d @@ -442,6 +466,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): def forward( ctx, qkv, + learnable_sink, dropout_p, softmax_scale, causal, @@ -465,6 +490,7 @@ def forward( q, k, v, + learnable_sink, dropout_p, softmax_scale, causal=causal, @@ -475,7 +501,7 @@ def forward( return_softmax=return_softmax and dropout_p > 0, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -488,9 +514,10 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -500,11 +527,13 @@ def backward(ctx, dout, *args): q, k, v, + learnable_sink, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + dsink, ctx.dropout_p, ctx.softmax_scale, ctx.causal, @@ -516,7 +545,7 @@ def backward(ctx, dout, *args): rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None + return dqkv, dsink, None, None, None, None, None, None, None, None, None class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): @@ -524,6 +553,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): def forward( ctx, qkv, + learnable_sink, cu_seqlens, max_seqlen, dropout_p, @@ -549,6 +579,7 @@ def forward( q, k, v, + learnable_sink, cu_seqlens, cu_seqlens, max_seqlen, @@ -564,7 +595,7 @@ def forward( block_table=None, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, cu_seqlens, rng_state) ctx.dropout_p = dropout_p ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale @@ -578,9 +609,10 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + q, k, v, learnable_sink, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -590,11 +622,13 @@ def backward(ctx, dout, *args): q, k, v, + learnable_sink, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], + dsink, cu_seqlens, cu_seqlens, ctx.max_seqlen, @@ -610,7 +644,7 @@ def backward(ctx, dout, *args): rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None + return dqkv, dsink, None, None, None, None, None, None, None, None, None, None, None class FlashAttnKVPackedFunc(torch.autograd.Function): @@ -619,6 +653,7 @@ def forward( ctx, q, kv, + learnable_sink, dropout_p, softmax_scale, causal, @@ -644,6 +679,7 @@ def forward( q, k, v, + learnable_sink, dropout_p, softmax_scale, causal=causal, @@ -654,7 +690,7 @@ def forward( return_softmax=return_softmax and dropout_p > 0, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -667,10 +703,11 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -680,11 +717,13 @@ def backward(ctx, dout, *args): q, k, v, + learnable_sink, out, softmax_lse, dq, dkv[:, :, 0], dkv[:, :, 1], + dsink, ctx.dropout_p, ctx.softmax_scale, ctx.causal, @@ -697,7 +736,7 @@ def backward(ctx, dout, *args): ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None + return dq, dkv, dsink, None, None, None, None, None, None, None, None, None class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): @@ -706,6 +745,7 @@ def forward( ctx, q, kv, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -735,6 +775,7 @@ def forward( q, k, v, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -751,7 +792,7 @@ def forward( ) if is_grad: ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, learnable_sink, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q @@ -767,10 +808,11 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, learnable_sink, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -780,11 +822,13 @@ def backward(ctx, dout, *args): q, k, v, + learnable_sink, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1], + dsink, cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, @@ -801,7 +845,7 @@ def backward(ctx, dout, *args): ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dkv, dsink, None, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @@ -811,6 +855,7 @@ def forward( q, k, v, + learnable_sink, dropout_p, softmax_scale, causal, @@ -835,6 +880,7 @@ def forward( q, k, v, + learnable_sink, dropout_p, softmax_scale, causal=causal, @@ -845,7 +891,7 @@ def forward( return_softmax=return_softmax and dropout_p > 0, ) if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, rng_state) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -858,8 +904,9 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -869,187 +916,7 @@ def backward(ctx, dout, *args): q, k, v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_sink_forward", mutates_args=(), device_types="cuda") -def _flash_attn_sink_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - sink: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - return_softmax: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - q, k, v, sink = [maybe_contiguous(x) for x in (q, k, v, sink)] - out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.sink_fwd( - q, - k, - v, - sink, - None, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - None, - ) - return out, softmax_lse, S_dmask, rng_state - -@_torch_custom_op_wrapper("flash_attn::_flash_attn_sink_backward", mutates_args=("dq", "dk", "dv", "dsink"), device_types="cuda") -def _flash_attn_sink_backward( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - sink: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - dsink: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - alibi_slopes: Optional[torch.Tensor], - deterministic: bool, - rng_state: Optional[torch.Tensor] = None, -) -> torch.Tensor: - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, sink, out = [maybe_contiguous(x) for x in (dout, q, k, v, sink, out)] - ( - dq, - dk, - dv, - dsink, - softmax_d, - ) = flash_attn_gpu.sink_bwd( - dout, - q, - k, - v, - sink, - out, - softmax_lse, - dq, - dk, - dv, - dsink, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - None, - rng_state, - ) - return softmax_d - -_wrapped_flash_attn_sink_forward = _flash_attn_sink_forward -_wrapped_flash_attn_sink_backward = _flash_attn_sink_backward - -class FlashAttnSinkFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - sink, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_sink_forward( - q, - k, - v, - sink, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - if is_grad: - ctx.save_for_backward(q, k, v, sink, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_padded[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, sink, out, softmax_lse, rng_state = ctx.saved_tensors - dq, dk, dv, dsink = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.zeros_like(sink) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - _wrapped_flash_attn_sink_backward( - dout_padded, - q, - k, - v, - sink, + learnable_sink, out, softmax_lse, dq, @@ -1079,6 +946,7 @@ def forward( q, k, v, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1108,6 +976,7 @@ def forward( q, k, v, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1124,7 +993,7 @@ def forward( ) if is_grad: ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, learnable_sink, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q @@ -1141,8 +1010,9 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, learnable_sink, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -1157,6 +1027,7 @@ def backward(ctx, dout, *args): dq, dk, dv, + dsink, cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, @@ -1174,11 +1045,12 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, dsink, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( qkv, + learnable_sink=None, dropout_p=0.0, softmax_scale=None, causal=False, @@ -1200,6 +1072,8 @@ def flash_attn_qkvpacked_func( Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim) + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1224,6 +1098,7 @@ def flash_attn_qkvpacked_func( """ return FlashAttnQKVPackedFunc.apply( qkv, + learnable_sink, dropout_p, softmax_scale, causal, @@ -1239,6 +1114,7 @@ def flash_attn_qkvpacked_func( def flash_attn_kvpacked_func( q, kv, + learnable_sink, dropout_p=0.0, softmax_scale=None, causal=False, @@ -1276,6 +1152,8 @@ def flash_attn_kvpacked_func( Arguments: q: (batch_size, seqlen, nheads, headdim) kv: (batch_size, seqlen, 2, nheads_k, headdim) + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1302,6 +1180,7 @@ def flash_attn_kvpacked_func( return FlashAttnKVPackedFunc.apply( q, kv, + learnable_sink, dropout_p, softmax_scale, causal, @@ -1318,6 +1197,7 @@ def flash_attn_func( q, k, v, + learnable_sink, dropout_p=0.0, softmax_scale=None, causal=False, @@ -1353,6 +1233,8 @@ def flash_attn_func( q: (batch_size, seqlen, nheads, headdim) k: (batch_size, seqlen, nheads_k, headdim) v: (batch_size, seqlen, nheads_k, headdim) + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1379,37 +1261,7 @@ def flash_attn_func( q, k, v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -def flash_attn_sink_func( - q, - k, - v, - sink, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnSinkFunc.apply( - q, - k, - v, - sink, + learnable_sink, dropout_p, softmax_scale, causal, @@ -1424,6 +1276,7 @@ def flash_attn_sink_func( def flash_attn_varlen_qkvpacked_func( qkv, + learnable_sink, cu_seqlens, max_seqlen, dropout_p=0.0, @@ -1447,6 +1300,8 @@ def flash_attn_varlen_qkvpacked_func( Arguments: qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv. max_seqlen: int. Maximum sequence length in the batch. @@ -1474,6 +1329,7 @@ def flash_attn_varlen_qkvpacked_func( """ return FlashAttnVarlenQKVPackedFunc.apply( qkv, + learnable_sink, cu_seqlens, max_seqlen, dropout_p, @@ -1491,6 +1347,7 @@ def flash_attn_varlen_qkvpacked_func( def flash_attn_varlen_kvpacked_func( q, kv, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1532,6 +1389,8 @@ def flash_attn_varlen_kvpacked_func( Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths @@ -1564,6 +1423,7 @@ def flash_attn_varlen_kvpacked_func( return FlashAttnVarlenKVPackedFunc.apply( q, kv, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1584,6 +1444,7 @@ def flash_attn_varlen_func( q, k, v, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1624,6 +1485,8 @@ def flash_attn_varlen_func( q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths @@ -1657,6 +1520,7 @@ def flash_attn_varlen_func( q, k, v, + learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1692,6 +1556,7 @@ def flash_attn_with_kvcache( softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, alibi_slopes=None, + learnable_sink=None, num_splits=0, return_softmax_lse=False, ): @@ -1770,6 +1635,8 @@ def flash_attn_with_kvcache( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. num_splits: int. If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. @@ -1807,6 +1674,7 @@ def flash_attn_with_kvcache( cache_leftpad, block_table, alibi_slopes, + learnable_sink, None, softmax_scale, causal, From 9a57df246cff4054de5c96bb22ce5731d30d2bc9 Mon Sep 17 00:00:00 2001 From: jerryao Date: Thu, 21 Aug 2025 17:29:20 +0800 Subject: [PATCH 11/29] Modify bwd. --- csrc/flash_attn/flash_api.cpp | 23 ++++++++-------- csrc/flash_attn/src/flash_bwd_kernel.h | 26 ++++++++----------- .../src/flash_bwd_launch_template.h | 12 ++++----- csrc/flash_attn/src/flash_fwd_kernel.h | 6 ++--- .../src/flash_fwd_launch_template.h | 4 +-- flash_attn/flash_attn_interface.py | 16 ++++++------ 6 files changed, 42 insertions(+), 45 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 597c66aca8a..32e8907dbaf 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -347,7 +347,7 @@ void set_params_alibi(Flash_fwd_params ¶ms, std::optional &alibi #endif } -void set_params_sink(Flash_fwd_params ¶ms, const std::optional &learnable_sink_, int num_heads, const std::optional &dsink_=std::nullopt) { +void set_params_sink(Flash_fwd_params ¶ms, const std::optional &learnable_sink_, int num_heads) { #ifdef FLASHATTENTION_DISABLE_ALIBI TORCH_CHECK(!learnable_sink_.has_value(), "This flash attention build does not support learnable sink."); params.learnable_sink_ptr = nullptr; @@ -359,17 +359,8 @@ void set_params_sink(Flash_fwd_params ¶ms, const std::optional 0) { launch(params, stream); @@ -1140,6 +1136,10 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (learnable_sink.has_value()) { if (dsink_.has_value()) { dsink = dsink_.value(); + TORCH_CHECK(dsink.dtype() == torch::kFloat32, "dsink must have dtype fp32"); + CHECK_DEVICE(dsink); + TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension"); + CHECK_SHAPE(dsink, num_heads); } else { dsink = torch::zeros_like(v); } @@ -1233,7 +1233,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink, num_heads, dsink); + set_params_sink(params, learnable_sink, num_heads); + params.dsink_ptr = learnable_sink.has_value() ? dsink.data_ptr() : nullptr; if (max_seqlen_q > 0) { launch(params, stream); diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 9f666e5e1f0..18bc6164267 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -77,7 +77,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -86,7 +86,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Shared memory. extern __shared__ char smem_[]; - __shared__ float shared_sink_val; float dsink_val = 0.f; // The thread index. @@ -107,9 +106,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); } - if constexpr (Has_sink) { - if (tidx == 0) { shared_sink_val = static_cast(reinterpret_cast(params.sink_ptr)[bidh]); } - } const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; @@ -800,11 +796,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); if constexpr (Has_sink) { + const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; SumOp sum_op; dsink_val = Allreduce<32>::run(dsink_val, sum_op); - if (tidx % 32 == 0) { + if (tidx % 32 == 0 && params.dsink_ptr != nullptr) { float* dsink_ptr = reinterpret_cast(params.dsink_ptr); - float val = -dsink_val * exp2f(shared_sink_val * float(M_LOG2E)); + float val = -dsink_val * exp2f(sink_val * float(M_LOG2E)); atomicAdd(dsink_ptr + bidh, val); } } @@ -813,7 +810,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv(const Params ¶ms) { // The block index for the batch. @@ -826,22 +823,21 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int tidx = threadIdx.x; const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - printf("n_block_max = %d\n", n_block_max); if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. @@ -851,7 +847,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 3b3b1e79640..7cd825cd7ce 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -31,18 +31,18 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Has_sink, bool Is_even_M, bool Is_even_K) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_dq_dk_dv(params); + FLASH_NAMESPACE::compute_dq_dk_dv(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Has_sink) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Has_sink, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -99,11 +99,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - SINK_SWITCH(params.sink_ptr != nullptr, Has_sink, [&] { + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 7e3ac17820d..70a9d248c42 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -282,7 +282,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); - const float sink_val = !Has_sink || params.sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.sink_ptr)[bidh]; + const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; @@ -535,7 +535,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } - if (tidx == 0) printf("compute_attn_1rowblock_splitkv: m_block = %d, binfo.actual_seqlen_q = %d, kBlockM = %d, binfo.actual_seqlen_k = %d, kBlockN = %d, params.sink_ptr = %p, Split = %d\n", m_block, binfo.actual_seqlen_q, kBlockM, binfo.actual_seqlen_k, kBlockN, params.sink_ptr, Split); + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. // Otherwise we might read OOB elements from gK and gV, @@ -835,7 +835,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons clear(acc_o); - const float sink_val = !Has_sink || params.sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.sink_ptr)[bidh]; + const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 4b9301291f4..de47f891e8c 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -71,7 +71,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - SINK_SWITCH(params.sink_ptr != nullptr, Has_sink, [&] { + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { // Will only return softmax if dropout, to reduce compilation time. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If return_softmax, set IsEvenMNConst to false to reduce number of templates @@ -116,7 +116,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { - SINK_SWITCH(params.sink_ptr != nullptr, Has_sink, [&] { + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Is_local, set Is_causal to false diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 5867bbe2f13..3d1dbfa5453 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -136,7 +136,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "12.4.0": +if torch.__version__ >= "2.4.0": _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward else: _wrapped_flash_attn_forward = _flash_attn_forward @@ -271,7 +271,7 @@ def _flash_attn_backward( dv, dsink, softmax_d, - ) = flash_attn_gpu.sink_bwd( + ) = flash_attn_gpu.bwd( dout, q, k, @@ -517,7 +517,7 @@ def backward(ctx, dout, *args): q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) + dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -612,7 +612,7 @@ def backward(ctx, dout, *args): q, k, v, learnable_sink, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) + dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -707,7 +707,7 @@ def backward(ctx, dout, *args): dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) + dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -812,7 +812,7 @@ def backward(ctx, dout, *args): dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) + dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -906,7 +906,7 @@ def forward( def backward(ctx, dout, *args): q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) + dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -1012,7 +1012,7 @@ def forward( def backward(ctx, dout, *args): q, k, v, learnable_sink, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) + dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: From 4cf0f8872af4e4403a45e46b886a6d03fb2500ea Mon Sep 17 00:00:00 2001 From: jerryao Date: Thu, 21 Aug 2025 19:51:42 +0800 Subject: [PATCH 12/29] Fix. --- csrc/flash_attn/flash_api.cpp | 2 +- flash_attn/__init__.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 32e8907dbaf..c2eb91f09a2 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -348,7 +348,7 @@ void set_params_alibi(Flash_fwd_params ¶ms, std::optional &alibi } void set_params_sink(Flash_fwd_params ¶ms, const std::optional &learnable_sink_, int num_heads) { -#ifdef FLASHATTENTION_DISABLE_ALIBI +#ifdef FLASHATTENTION_DISABLE_SINK TORCH_CHECK(!learnable_sink_.has_value(), "This flash attention build does not support learnable sink."); params.learnable_sink_ptr = nullptr; #else diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 7657c85eec1..69eae460e36 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -2,7 +2,6 @@ from flash_attn.flash_attn_interface import ( flash_attn_func, - flash_attn_sink_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func, flash_attn_varlen_func, From 9880064cbdea7bce2d4b181a96ad2330bb50f085 Mon Sep 17 00:00:00 2001 From: jerryao Date: Thu, 21 Aug 2025 23:43:03 +0800 Subject: [PATCH 13/29] Fix lse in combine_attn_seqk_parallel. --- csrc/flash_attn/src/flash_fwd_kernel.h | 37 +++++++++++++++--- .../src/flash_fwd_launch_template.h | 38 ++++++++++--------- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 70a9d248c42..89200f3381e 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -997,7 +997,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // We fix the lse for sink in combine_attn_seqk_parallel. + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1110,7 +1111,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -1190,15 +1191,41 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { MaxOp max_op; lse_max = Allreduce::run(lse_max, max_op); lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum = expf(lse_accum(0) - lse_max); + float lse_sum = __expf(lse_accum(0) - lse_max); #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += __expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : __logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + + if constexpr (Has_sink) { + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (params.unpadded_lse) { + // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). + if (lse_offset < lse_size) { + const int head_idx = lse_offset / (params.b * params.seqlen_q); + const float lse_logsum_sink = __expf(lse_logsum); + const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx]); + lse_logsum = __logf(lse_logsum_sink + sink_val_exp);; + } + } else { + // LSE is written as (b, h, seqlen_q). + const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; + const float lse_logsum_sink = __expf(lse_logsum); + const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx]); + lse_logsum = __logf(lse_logsum_sink + sink_val_exp); + } + } + } + } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { if (params.unpadded_lse) { const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index de47f891e8c..49ec54a5a03 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -46,9 +46,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K, bool Has_sink) { static_assert(Log_max_splits >= 1); - FLASH_NAMESPACE::combine_attn_seqk_parallel(params); + FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } template @@ -144,22 +144,24 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SINK_SWITCH(params.learnable_sink_ptr != nullptr, Has_sink, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); } } From 91173041d18543b9d75e4e4a43b9d0d51d0480ac Mon Sep 17 00:00:00 2001 From: jerryao Date: Thu, 21 Aug 2025 23:54:10 +0800 Subject: [PATCH 14/29] Update tests. --- tests/test_flash_attn.py | 41 +++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index e9331f7b094..97ac11c7ff9 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -7,7 +7,6 @@ from einops import rearrange, repeat from flash_attn import ( flash_attn_func, - flash_attn_sink_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func, flash_attn_varlen_func, @@ -1530,12 +1529,12 @@ def test_flash_attn_varlen_output( # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) -# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("swap_sq_sk", [False, True]) # @pytest.mark.parametrize("swap_sq_sk", [True]) @pytest.mark.parametrize( @@ -2599,11 +2598,35 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) -@pytest.mark.parametrize("dtype", ([torch.bfloat16])) -@pytest.mark.parametrize("d", [64]) -@pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("swap_sq_sk", [True]) -@pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 512)]) + +@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# @pytest.mark.parametrize("swap_sq_sk", [True]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + # (1, 239), # TODO: fix this + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -2627,7 +2650,7 @@ def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype) k_ref = k.transpose(1, 2) v_ref = v.transpose(1, 2) sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) - out = flash_attn_sink_func(q, k, v, sink, 0.0, softmax_scale=d**-0.5, causal=causal, window_size=window_size) + out = flash_attn_func(q, k, v, sink, 0.0, softmax_scale=d**-0.5, causal=causal, window_size=window_size) attention_mask = get_attention_mask(seqlen_q, seqlen_k, causal, device, window_size) if attention_mask is not None: From 46b8ee99f8b0e4f8de43d8733f03f738548cacfa Mon Sep 17 00:00:00 2001 From: jerryao Date: Fri, 22 Aug 2025 09:47:45 +0800 Subject: [PATCH 15/29] Clean code. --- benchmarks/flash_attn_with_sink.py | 206 ----------------- benchmarks/flash_attn_with_sink_fused.py | 122 ---------- benchmarks/naive_attn_with_sink.py | 53 ----- benchmarks/test.py | 233 ------------------- benchmarks/test_fused.py | 236 ------------------- csrc/flash_attn/flash_api.cpp | 14 +- tests/test_flash_attn.py | 282 +---------------------- 7 files changed, 8 insertions(+), 1138 deletions(-) delete mode 100644 benchmarks/flash_attn_with_sink.py delete mode 100644 benchmarks/flash_attn_with_sink_fused.py delete mode 100644 benchmarks/naive_attn_with_sink.py delete mode 100644 benchmarks/test.py delete mode 100644 benchmarks/test_fused.py diff --git a/benchmarks/flash_attn_with_sink.py b/benchmarks/flash_attn_with_sink.py deleted file mode 100644 index ac3af82e788..00000000000 --- a/benchmarks/flash_attn_with_sink.py +++ /dev/null @@ -1,206 +0,0 @@ -import torch -from flash_attn import flash_attn_func -from flash_attn.flash_attn_interface import _flash_attn_backward - - -class FlashAttentionWithSink(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - sink: torch.Tensor, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - ): - # Check device - if q.device.type != 'cuda': - raise RuntimeError( - f"Flash Attention only supports CUDA devices, " - f"current device: {q.device}" - ) - - ctx.save_for_backward(q, k, v, sink, alibi_slopes) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.return_attn_probs = return_attn_probs - ctx.sink_shape = sink.shape # Save original sink shape - - # import pdb; pdb.set_trace() - - out, lse, _ = flash_attn_func( - q, - k, - v, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - origin_dtype = out.dtype - - ctx.raw_output = out.clone() - ctx.lse = lse.clone() - - lse = lse.transpose(-2, -1).unsqueeze(dim=-1) - sink = sink.reshape(1, 1, -1, 1) - - multiplier = 1 / (torch.exp(sink - lse) + 1) - out = (out * multiplier).to(origin_dtype) - - return out - - @staticmethod - def backward(ctx, grad_output): - q, k, v, sink, alibi_slopes = ctx.saved_tensors - raw_output = ctx.raw_output - lse = ctx.lse - - lse = lse.transpose(-2, -1).unsqueeze(dim=-1) - sink_reshaped = sink.reshape(1, 1, -1, 1) - multiplier = 1 / (torch.exp(sink_reshaped - lse) + 1) - - # 1) Main path via multiplier - grad_raw_output = (grad_output * multiplier).to(q.dtype) - - # Use flash attention backward function for main path - grad_q_main = torch.empty_like(q) - grad_k_main = torch.empty_like(k) - grad_v = torch.empty_like(v) - - _flash_attn_backward( - grad_raw_output, # dout: main path gradient - q, # q - k, # k - v, # v - ctx.raw_output, # out: original output - lse, # softmax_lse - grad_q_main, # dq: main path - grad_k_main, # dk: main path - grad_v, # dv: main path - ctx.dropout_p, # dropout_p - ctx.softmax_scale, # softmax_scale - ctx.causal, # causal - ctx.window_size[0], # window_size_left - ctx.window_size[1], # window_size_right - ctx.softcap, # softcap - alibi_slopes, # alibi_slopes - ctx.deterministic, # deterministic - ) - - # 2) Sink gradient path - # g_r = (grad_output * raw_output).sum(dim=-1) # [B,H,Nq] - g_r = torch.sum(grad_output * raw_output, dim=-1) - - # g_ell = g_r * multiplier * (1 - multiplier) # [B,H,Nq] - # Based on debug output: - # g_r shape: [1, 512, 64] (batch, seq_len, heads) - # multiplier shape: [1, 512, 64, 1] (batch, seq_len, heads, 1) - # We need multiplier_for_grad to have shape [1, 512, 64] - # [1, 512, 64, 1] -> [1, 512, 64] - multiplier_for_grad = multiplier.squeeze(-1) - - g_ell = g_r * multiplier_for_grad * (1 - multiplier_for_grad) - # Based on shapes: g_ell [1, 512, 64], we need to sum over seq_len (dim=1) - # to get [1, 64], then sum over batch (dim=0) to get [64] - grad_sink = -torch.sum(g_ell, dim=1) # Sum over seq_len -> [1, 64] - # Sum over batch dimension and reshape to match original sink shape - grad_sink = grad_sink.sum(dim=0) # Sum over batch -> [64] - grad_sink = grad_sink.reshape(ctx.sink_shape) - - # 3) Additional Q gradient via sink - # dQ_extra = scale * g_ell * attention(Q,K,K) - scale = ctx.softmax_scale or (1.0 / q.shape[-1] ** 0.5) - - # Compute attention(Q,K,K) for additional Q gradient - mu_k = flash_attn_func( - q, k, k, - dropout_p=ctx.dropout_p, - softmax_scale=ctx.softmax_scale, - causal=ctx.causal, - window_size=ctx.window_size, - softcap=ctx.softcap, - alibi_slopes=alibi_slopes, - deterministic=ctx.deterministic, - return_attn_probs=False, - ) - grad_q_extra = scale * g_ell.unsqueeze(-1) * mu_k - - # 4) Additional K gradient via sink - # dK_extra = scale * P^T (g_ell * Q) - x = (g_ell.unsqueeze(-1) * q).to(q.dtype) - - # Use flash attention backward to compute P^T X - grad_k_extra = torch.empty_like(k) - _flash_attn_backward( - x, # dout: g_ell * Q - q, # q - k, # k - k, # v (dummy, using K as V) - ctx.raw_output, # out: original output - lse, # softmax_lse - None, # dq: not needed - None, # dk: not needed - grad_k_extra, # dv: this will be dK_extra - ctx.dropout_p, # dropout_p - ctx.softmax_scale, # softmax_scale - ctx.causal, # causal - ctx.window_size[0], # window_size_left - ctx.window_size[1], # window_size_right - ctx.softcap, # softcap - alibi_slopes, # alibi_slopes - ctx.deterministic, # deterministic - ) - grad_k_extra = scale * grad_k_extra - - # 5) Sum all gradients - grad_q = grad_q_main + grad_q_extra - grad_k = grad_k_main + grad_k_extra - # grad_v already from main path - - return (grad_q, grad_k, grad_v, grad_sink, None, None, None, None, - None, None, None, None) - - -def flash_attn_with_sink_func( - q, - k, - v, - sink: torch.Tensor, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - # Check CUDA availability - if not torch.cuda.is_available(): - raise RuntimeError( - "Flash Attention requires CUDA devices. " - "Current device does not support CUDA." - ) - - return FlashAttentionWithSink.apply( - q, k, v, sink, dropout_p, softmax_scale, causal, - window_size, softcap, alibi_slopes, deterministic, return_attn_probs - ) diff --git a/benchmarks/flash_attn_with_sink_fused.py b/benchmarks/flash_attn_with_sink_fused.py deleted file mode 100644 index e19194ffd7d..00000000000 --- a/benchmarks/flash_attn_with_sink_fused.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -from flash_attn import flash_attn_sink_func -from flash_attn.flash_attn_interface import _flash_attn_sink_backward - - -class FlashAttentionWithSinkFused(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - q, - k, - v, - sink: torch.Tensor, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - ): - # Check device - if q.device.type != 'cuda': - raise RuntimeError( - f"Flash Attention only supports CUDA devices, " - f"current device: {q.device}" - ) - - ctx.save_for_backward(q, k, v, sink, alibi_slopes) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.deterministic = deterministic - ctx.return_attn_probs = return_attn_probs - ctx.sink_shape = sink.shape # Save original sink shape - - out, lse, _ = flash_attn_sink_func( - q, - k, - v, - sink, - dropout_p=dropout_p, - softmax_scale=softmax_scale, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - ctx.lse = lse - ctx.output = out - - return out - - - @staticmethod - def backward(ctx, grad_output): - q, k, v, sink, alibi_slopes = ctx.saved_tensors - lse = ctx.lse - - grad_q = torch.empty_like(q) - grad_k = torch.empty_like(k) - grad_v = torch.empty_like(v) - grad_sink = torch.empty_like(sink) - - _flash_attn_sink_backward( - grad_output, # dout: main path gradient - q, # q - k, # k - v, # v - sink, - ctx.output, # out: original output - lse, # softmax_lse - grad_q, # dq: main path - grad_k, # dk: main path - grad_v, # dv: main path - grad_sink, # ds: sink gradient - ctx.dropout_p, # dropout_p - ctx.softmax_scale, # softmax_scale - ctx.causal, # causal - ctx.window_size[0], # window_size_left - ctx.window_size[1], # window_size_right - ctx.softcap, # softcap - alibi_slopes, # alibi_slopes - ctx.deterministic, # deterministic - ) - - - return (grad_q, grad_k, grad_v, grad_sink, None, None, None, None, - None, None, None, None) - - -def flash_attn_with_sink_fused_func( - q, - k, - v, - sink: torch.Tensor, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - # Check CUDA availability - if not torch.cuda.is_available(): - raise RuntimeError( - "Flash Attention requires CUDA devices. " - "Current device does not support CUDA." - ) - - return FlashAttentionWithSinkFused.apply( - q, k, v, sink, dropout_p, softmax_scale, causal, - window_size, softcap, alibi_slopes, deterministic, return_attn_probs - ) diff --git a/benchmarks/naive_attn_with_sink.py b/benchmarks/naive_attn_with_sink.py deleted file mode 100644 index d5f8fb1727b..00000000000 --- a/benchmarks/naive_attn_with_sink.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Optional - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - print(num_key_value_heads, n_rep) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -# from https://github.com/huggingface/transformers/blob/369c99d0cea403b77bd0aef818527106453fd9fc/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L227 -def eager_attention_forward( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sink: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - num_key_value_groups: int = 8, - **kwargs, -): - key_states = repeat_kv(key, num_key_value_groups) - value_states = repeat_kv(value, num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - sinks = sink.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - combined_logits = torch.cat([attn_weights, sinks], dim=-1) - # This was not in the original implementation and slightly affect results; - # it prevents overflow in BF16/FP16 when training with bsz>1 we clamp max values. - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - - probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) - scores = probs[..., :-1] # we drop the sink here - attn_weights = nn.functional.dropout(scores, p=dropout, training=True) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights diff --git a/benchmarks/test.py b/benchmarks/test.py deleted file mode 100644 index e24098caaee..00000000000 --- a/benchmarks/test.py +++ /dev/null @@ -1,233 +0,0 @@ -import torch -import torch.nn.functional as F -from flash_attn_with_sink import flash_attn_with_sink_func -from naive_attn_with_sink import eager_attention_forward - - -if __name__ == "__main__": - batch = 1 - num_attention_heads = 64 - num_key_value_heads = 8 - num_key_value_groups = num_attention_heads // num_key_value_heads - head_dim = 64 - seq_len = 512 - scaling = head_dim**-0.5 - torch.manual_seed(0) - - torch.cuda.set_device(0) - query = torch.randn( - (batch, num_attention_heads, seq_len, head_dim), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - key = torch.randn( - (batch, num_key_value_heads, seq_len, head_dim), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - value = torch.randn( - (batch, num_key_value_heads, seq_len, head_dim), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - sink = torch.randn( - (num_attention_heads,), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - - # Create causal attention mask - # The mask should be of shape (batch, num_heads, seq_len, seq_len) - # For causal attention, we mask out future positions - # (set them to large negative value) - attention_mask = torch.triu( - torch.full( - (seq_len, seq_len), float("-inf"), device="cuda", dtype=torch.bfloat16 - ), - diagonal=1, - ) - attention_mask = ( - attention_mask.unsqueeze(0) - .unsqueeze(0) - .expand(batch, num_attention_heads, -1, -1) - ) - - print("Running eager attention forward...") - eager_output, eager_weights = eager_attention_forward( - query, - key, - value, - sink, - attention_mask=attention_mask, - scaling=scaling, - dropout=0.0, - num_key_value_groups=num_key_value_groups, - ) - - print("Running flash attention forward...") - # Reshape tensors for flash attention (batch, seq_len, num_heads, head_dim) - q_flash = query.transpose(1, 2) # (batch, seq_len, num_heads, head_dim) - k_flash = key.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) - v_flash = value.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) - - flash_output = flash_attn_with_sink_func( - q_flash, - k_flash, - v_flash, - sink, - softmax_scale=scaling, - dropout_p=0.0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - ) - - # Compare outputs - print(f"Eager output shape: {eager_output.shape}, dtype: {eager_output.dtype}") - print(f"Flash output shape: {flash_output.shape}, dtype: {flash_output.dtype}") - - print( - f"Max absolute difference: {torch.max(torch.abs(eager_output - flash_output))}" - ) - print( - f"Mean absolute difference: {torch.mean(torch.abs(eager_output - flash_output))}" - ) - print( - f"Relative error: {torch.mean(torch.abs(eager_output - flash_output) / (torch.abs(eager_output) + 1e-8))}" - ) - - print("\nEager output sample (first 5x5 elements):") - print(eager_output[0, 0, :5, :5]) - print("\nFlash output sample (first 5x5 elements):") - print(flash_output[0, 0, :5, :5]) - print("eager_output / flash_output:\n", eager_output[0, 0, :8, :8] / flash_output[0, 0, :8, :8]) - - # Test backward pass - print("\n" + "=" * 50) - print("Testing backward pass...") - - # Reset gradients (handle None case) - if query.grad is not None: - query.grad.zero_() - if key.grad is not None: - key.grad.zero_() - if value.grad is not None: - value.grad.zero_() - if sink.grad is not None: - sink.grad.zero_() - - # Compute loss for eager attention - target = torch.randn_like(eager_output, device="cuda") - eager_loss = F.mse_loss(eager_output, target) * 1000 - - print(f"Eager loss: {eager_loss.item():.6f}") - - # Backward pass for eager attention - eager_loss.backward() - - # Save eager gradients - eager_query_grad = query.grad.clone() - eager_key_grad = key.grad.clone() - eager_value_grad = value.grad.clone() - eager_sink_grad = sink.grad.clone() - - print("\nEager gradient information:") - print(f"Query gradient norm: {eager_query_grad.norm().item():.6f}") - print(f"Key gradient norm: {eager_key_grad.norm().item():.6f}") - print(f"Value gradient norm: {eager_value_grad.norm().item():.6f}") - print(f"Sink gradient norm: {eager_sink_grad.norm().item():.6f}") - - # Reset gradients for flash attention (handle None case) - if query.grad is not None: - query.grad.zero_() - if key.grad is not None: - key.grad.zero_() - if value.grad is not None: - value.grad.zero_() - if sink.grad is not None: - sink.grad.zero_() - - # Compute loss for flash attention - flash_loss = F.mse_loss(flash_output, target) * 1000 - - print(f"\nFlash loss: {flash_loss.item():.6f}") - - # Backward pass for flash attention - flash_loss.backward() - - # Save flash gradients - flash_query_grad = query.grad.clone() - flash_key_grad = key.grad.clone() - flash_value_grad = value.grad.clone() - flash_sink_grad = sink.grad.clone() - - print("\nFlash gradient information:") - print(f"Query gradient norm: {flash_query_grad.norm().item():.6f}") - print(f"Key gradient norm: {flash_key_grad.norm().item():.6f}") - print(f"Value gradient norm: {flash_value_grad.norm().item():.6f}") - print(f"Sink gradient norm: {flash_sink_grad.norm().item():.6f}") - - # Compare gradients - print("\n" + "=" * 50) - print("Comparing gradients...") - - # Calculate gradient differences - query_grad_diff = torch.abs(eager_query_grad - flash_query_grad).max().item() - key_grad_diff = torch.abs(eager_key_grad - flash_key_grad).max().item() - value_grad_diff = torch.abs(eager_value_grad - flash_value_grad).max().item() - sink_grad_diff = torch.abs(eager_sink_grad - flash_sink_grad).max().item() - - print(f"Query gradient max difference: {query_grad_diff:.2e}") - print(f"Key gradient max difference: {key_grad_diff:.2e}") - print(f"Value gradient max difference: {value_grad_diff:.2e}") - print(f"Sink gradient max difference: {sink_grad_diff:.2e}") - - # Check if gradients are close (within tolerance) - tolerance = 1e-3 # Adjust tolerance as needed - query_grad_close = query_grad_diff < tolerance - key_grad_close = key_grad_diff < tolerance - value_grad_close = value_grad_diff < tolerance - sink_grad_close = sink_grad_diff < tolerance - - print(f"\nGradient comparison (tolerance: {tolerance}):") - print(f"Query gradients close: {'✅' if query_grad_close else '❌'}") - print(f"Key gradients close: {'✅' if key_grad_close else '❌'}") - print(f"Value gradients close: {'✅' if value_grad_close else '❌'}") - print(f"Sink gradients close: {'✅' if sink_grad_close else '❌'}") - - # Check if gradients are non-zero - query_grad_zero = eager_query_grad.norm().item() < 1e-8 - key_grad_zero = eager_key_grad.norm().item() < 1e-8 - value_grad_zero = eager_value_grad.norm().item() < 1e-8 - sink_grad_zero = eager_sink_grad.norm().item() < 1e-8 - - print(f"\nGradient non-zero check:") - print(f"Query gradient non-zero: {'✅' if not query_grad_zero else '❌'}") - print(f"Key gradient non-zero: {'✅' if not key_grad_zero else '❌'}") - print(f"Value gradient non-zero: {'✅' if not value_grad_zero else '❌'}") - print(f"Sink gradient non-zero: {'✅' if not sink_grad_zero else '❌'}") - - all_grads_close = ( - query_grad_close and key_grad_close and value_grad_close and sink_grad_close - ) - all_grads_nonzero = not ( - query_grad_zero or key_grad_zero or value_grad_zero or sink_grad_zero - ) - - print(f"\nOverall result:") - print(f" All gradients close: {'✅' if all_grads_close else '❌'}") - print(f" All gradients non-zero: {'✅' if all_grads_nonzero else '❌'}") - - if all_grads_close and all_grads_nonzero: - print("\n🎉 Backward test passed! Gradients match and are non-zero.") - else: - print("\n❌ Backward test failed!") - if not all_grads_close: - print(" - Some gradients don't match between eager and flash attention") - if not all_grads_nonzero: - print(" - Some gradients are zero") diff --git a/benchmarks/test_fused.py b/benchmarks/test_fused.py deleted file mode 100644 index 849e6cfff66..00000000000 --- a/benchmarks/test_fused.py +++ /dev/null @@ -1,236 +0,0 @@ -import torch -import torch.nn.functional as F -from flash_attn_with_sink_fused import flash_attn_with_sink_fused_func -from naive_attn_with_sink import eager_attention_forward - - -if __name__ == "__main__": - batch = 1 - num_attention_heads = 64 - num_key_value_heads = 8 - num_key_value_groups = num_attention_heads // num_key_value_heads - head_dim = 64 - seq_len = 512 - scaling = head_dim**-0.5 - torch.manual_seed(0) - - torch.cuda.set_device(0) - query = torch.randn( - (batch, num_attention_heads, seq_len, head_dim), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - key = torch.randn( - (batch, num_key_value_heads, seq_len, head_dim), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - value = torch.randn( - (batch, num_key_value_heads, seq_len, head_dim), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - sink = torch.randn( - (num_attention_heads,), - dtype=torch.float32, - device="cuda", - requires_grad=True, - ) - - # Create causal attention mask - # The mask should be of shape (batch, num_heads, seq_len, seq_len) - # For causal attention, we mask out future positions - # (set them to large negative value) - attention_mask = torch.triu( - torch.full( - (seq_len, seq_len), float("-inf"), device="cuda", dtype=torch.bfloat16 - ), - diagonal=1, - ) - attention_mask = ( - attention_mask.unsqueeze(0) - .unsqueeze(0) - .expand(batch, num_attention_heads, -1, -1) - ) - - print("Running eager attention forward...") - eager_output, eager_weights = eager_attention_forward( - query, - key, - value, - sink.to(torch.bfloat16), - attention_mask=attention_mask, - scaling=scaling, - dropout=0.0, - num_key_value_groups=num_key_value_groups, - ) - - print("Running flash attention forward...") - # Reshape tensors for flash attention (batch, seq_len, num_heads, head_dim) - q_flash = query.transpose(1, 2) # (batch, seq_len, num_heads, head_dim) - k_flash = key.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) - v_flash = value.transpose(1, 2) # (batch, seq_len, num_kv_heads, head_dim) - - flash_output = flash_attn_with_sink_fused_func( - q_flash, - k_flash, - v_flash, - sink, - softmax_scale=scaling, - dropout_p=0.0, - causal=True, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - ) - - # Compare outputs - print(f"Eager output shape: {eager_output.shape}, dtype: {eager_output.dtype}") - print(f"Flash output shape: {flash_output.shape}, dtype: {flash_output.dtype}") - - print( - f"Max absolute difference: {torch.max(torch.abs(eager_output - flash_output))}" - ) - print( - f"Mean absolute difference: {torch.mean(torch.abs(eager_output - flash_output))}" - ) - print( - f"Relative error: {torch.mean(torch.abs(eager_output - flash_output) / (torch.abs(eager_output) + 1e-8))}" - ) - - print("\nEager output sample (first 8x8 elements):") - print(eager_output[0, 0, :8, :8]) - print("\nFlash output sample (first 8x8 elements):") - print(flash_output[0, 0, :8, :8]) - print("eager_output / flash_output:\n", eager_output[0, 0, :8, :8] / flash_output[0, 0, :8, :8]) - - # Test backward pass - print("\n" + "=" * 50) - print("Testing backward pass...") - - # Reset gradients (handle None case) - if query.grad is not None: - query.grad.zero_() - if key.grad is not None: - key.grad.zero_() - if value.grad is not None: - value.grad.zero_() - if sink.grad is not None: - sink.grad.zero_() - - # Compute loss for eager attention - target = torch.randn_like(eager_output, device="cuda") - eager_loss = F.mse_loss(eager_output, target) * 1000 - - print(f"Eager loss: {eager_loss.item():.6f}") - - # Backward pass for eager attention - eager_loss.backward() - - # Save eager gradients - eager_query_grad = query.grad.clone() - eager_key_grad = key.grad.clone() - eager_value_grad = value.grad.clone() - eager_sink_grad = sink.grad.clone() - - print("\nEager gradient information:") - print(f"Query gradient norm: {eager_query_grad.norm().item():.6f}") - print(f"Key gradient norm: {eager_key_grad.norm().item():.6f}") - print(f"Value gradient norm: {eager_value_grad.norm().item():.6f}") - print(f"Sink gradient norm: {eager_sink_grad.norm().item():.6f}") - - # Reset gradients for flash attention (handle None case) - if query.grad is not None: - query.grad.zero_() - if key.grad is not None: - key.grad.zero_() - if value.grad is not None: - value.grad.zero_() - if sink.grad is not None: - sink.grad.zero_() - - # Compute loss for flash attention - flash_loss = F.mse_loss(flash_output, target) * 1000 - - print(f"\nFlash loss: {flash_loss.item():.6f}") - - # Backward pass for flash attention - flash_loss.backward() - - # Save flash gradients - flash_query_grad = query.grad.clone() - flash_key_grad = key.grad.clone() - flash_value_grad = value.grad.clone() - flash_sink_grad = sink.grad.clone() - - print("\nFlash gradient information:") - print(f"Query gradient norm: {flash_query_grad.norm().item():.6f}") - print(f"Key gradient norm: {flash_key_grad.norm().item():.6f}") - print(f"Value gradient norm: {flash_value_grad.norm().item():.6f}") - print(f"Sink gradient norm: {flash_sink_grad.norm().item():.6f}") - - # Compare gradients - print("\n" + "=" * 50) - print("Comparing gradients...") - - # Calculate gradient differences - query_grad_diff = torch.abs(eager_query_grad - flash_query_grad).max().item() - key_grad_diff = torch.abs(eager_key_grad - flash_key_grad).max().item() - value_grad_diff = torch.abs(eager_value_grad - flash_value_grad).max().item() - sink_grad_diff = torch.abs(eager_sink_grad.to(flash_sink_grad.dtype) - flash_sink_grad).max().item() - - print("eager_sink_grad = ", eager_sink_grad) - print("flash_sink_grad = ", flash_sink_grad) - - print(f"Query gradient max difference: {query_grad_diff:.2e}") - print(f"Key gradient max difference: {key_grad_diff:.2e}") - print(f"Value gradient max difference: {value_grad_diff:.2e}") - print(f"Sink gradient max difference: {sink_grad_diff:.2e}") - - # Check if gradients are close (within tolerance) - tolerance = 1e-2 # Adjust tolerance as needed - query_grad_close = query_grad_diff < tolerance - key_grad_close = key_grad_diff < tolerance - value_grad_close = value_grad_diff < tolerance - sink_grad_close = sink_grad_diff < tolerance - - print(f"\nGradient comparison (tolerance: {tolerance}):") - print(f"Query gradients close: {'✅' if query_grad_close else '❌'}") - print(f"Key gradients close: {'✅' if key_grad_close else '❌'}") - print(f"Value gradients close: {'✅' if value_grad_close else '❌'}") - print(f"Sink gradients close: {'✅' if sink_grad_close else '❌'}") - - # Check if gradients are non-zero - query_grad_zero = eager_query_grad.norm().item() < 1e-8 - key_grad_zero = eager_key_grad.norm().item() < 1e-8 - value_grad_zero = eager_value_grad.norm().item() < 1e-8 - sink_grad_zero = eager_sink_grad.norm().item() < 1e-8 - - print(f"\nGradient non-zero check:") - print(f"Query gradient non-zero: {'✅' if not query_grad_zero else '❌'}") - print(f"Key gradient non-zero: {'✅' if not key_grad_zero else '❌'}") - print(f"Value gradient non-zero: {'✅' if not value_grad_zero else '❌'}") - print(f"Sink gradient non-zero: {'✅' if not sink_grad_zero else '❌'}") - - all_grads_close = ( - query_grad_close and key_grad_close and value_grad_close and sink_grad_close - ) - all_grads_nonzero = not ( - query_grad_zero or key_grad_zero or value_grad_zero or sink_grad_zero - ) - - print(f"\nOverall result:") - print(f" All gradients close: {'✅' if all_grads_close else '❌'}") - print(f" All gradients non-zero: {'✅' if all_grads_nonzero else '❌'}") - - if all_grads_close and all_grads_nonzero: - print("\n🎉 Backward test passed! Gradients match and are non-zero.") - else: - print("\n❌ Backward test failed!") - if not all_grads_close: - print(" - Some gradients don't match between eager and flash attention") - if not all_grads_nonzero: - print(" - Some gradients are zero") diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index c2eb91f09a2..d163997ae1c 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -369,7 +369,7 @@ std::vector mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &learnable_sink, // num_heads + std::optional &learnable_sink_, // num_heads std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, @@ -424,7 +424,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); @@ -513,7 +513,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink, num_heads); + set_params_sink(params, learnable_sink_, num_heads); if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -536,7 +536,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &learnable_sink, // num_heads + std::optional &learnable_sink_, // num_heads std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 @@ -611,7 +611,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); @@ -755,7 +755,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink, num_heads); + set_params_sink(params, learnable_sink_, num_heads); if (max_seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1331,7 +1331,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value(); if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 97ac11c7ff9..bd793343c2f 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -2614,7 +2614,7 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - # (1, 239), # TODO: fix this + (1, 239), # disable seqlenq_ngroups_swapped (3, 799), (127, 512), (127, 513), @@ -2712,283 +2712,3 @@ def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype) assert (dk - dk_ref).abs().max().detach().item() <= 2 * (dk_pt - dk_ref).abs().max().detach().item() + 1e-5 assert (dv - dv_ref).abs().max().detach().item() <= 2 * (dv_pt - dv_ref).abs().max().detach().item() + 1e-5 assert (dsink - dsink_ref).abs().max().detach().item() <= 2 * (dsink_pt - dsink_ref).abs().max().detach().item() + 1e-5 - - - -""" -:test_flash_attn_causal - -Running 16 items in this shard -Output max diff: 0.00048828125 -Output mean diff: 1.8358230590820312e-05 -Pytorch max diff: 0.0009765625 -Pytorch mean diff: 3.36766242980957e-05 -==== 256 512 False 64 False torch.float16 -dQ max diff: 0.00048828125 -dK max diff: 0.00048828125 -dV max diff: 0.000244140625 -dQ mean diff: 5.364418029785156e-06 -dK mean diff: 3.3974647521972656e-06 -dV mean diff: 1.1920928955078125e-07 -dQ Pytorch max diff: 0.00146484375 -dK Pytorch max diff: 0.00146484375 -dV Pytorch max diff: 0.0009765625 -dQ Pytorch mean diff: 3.927946090698242e-05 -dK Pytorch mean diff: 2.5272369384765625e-05 -dV Pytorch mean diff: 2.1457672119140625e-05 -.Output max diff: 0.00390625 -Output mean diff: 9.679794311523438e-05 -Pytorch max diff: 0.01171875 -Pytorch mean diff: 0.0002498626708984375 -==== 256 512 False 64 False torch.bfloat16 -dQ max diff: 0.0078125 -dK max diff: 0.0078125 -dV max diff: 0.00390625 -dQ mean diff: 0.0001125335693359375 -dK mean diff: 7.104873657226562e-05 -dV mean diff: 7.343292236328125e-05 -dQ Pytorch max diff: 0.01171875 -dK Pytorch max diff: 0.015625 -dV Pytorch max diff: 0.015625 -dQ Pytorch mean diff: 0.0002956390380859375 -dK Pytorch mean diff: 0.00019073486328125 -dV Pytorch mean diff: 0.0001583099365234375 -.Output max diff: 0.00048828125 -Output mean diff: 2.6047229766845703e-05 -Pytorch max diff: 0.001953125 -Pytorch mean diff: 4.690885543823242e-05 -==== 256 512 False 64 True torch.float16 -dQ max diff: 0.0009765625 -dK max diff: 0.0009765625 -dV max diff: 0.00048828125 -dQ mean diff: 8.881092071533203e-06 -dK mean diff: 5.245208740234375e-06 -dV mean diff: 1.1920928955078125e-07 -dQ Pytorch max diff: 0.001953125 -dK Pytorch max diff: 0.0029296875 -dV Pytorch max diff: 0.00146484375 -dQ Pytorch mean diff: 5.352497100830078e-05 -dK Pytorch mean diff: 3.135204315185547e-05 -dV Pytorch mean diff: 2.7120113372802734e-05 -.Output max diff: 0.00390625 -Output mean diff: 0.0001354217529296875 -Pytorch max diff: 0.0078125 -Pytorch mean diff: 0.00034332275390625 -==== 256 512 False 64 True torch.bfloat16 -dQ max diff: 0.0078125 -dK max diff: 0.0078125 -dV max diff: 0.0078125 -dQ mean diff: 0.00015926361083984375 -dK mean diff: 9.202957153320312e-05 -dV mean diff: 9.632110595703125e-05 -dQ Pytorch max diff: 0.015625 -dK Pytorch max diff: 0.015625 -dV Pytorch max diff: 0.015625 -dQ Pytorch mean diff: 0.000400543212890625 -dK Pytorch mean diff: 0.00023555755615234375 -dV Pytorch mean diff: 0.00019931793212890625 -.Output max diff: 0.00048828125 -Output mean diff: 2.2709369659423828e-05 -Pytorch max diff: 0.0009765625 -Pytorch mean diff: 3.910064697265625e-05 -==== 256 512 False 128 False torch.float16 -dQ max diff: 0.00048828125 -dK max diff: 0.0009765625 -dV max diff: 0.00048828125 -dQ mean diff: 2.1457672119140625e-05 -dK mean diff: 1.6450881958007812e-05 -dV mean diff: 1.436471939086914e-05 -dQ Pytorch max diff: 0.00244140625 -dK Pytorch max diff: 0.00244140625 -dV Pytorch max diff: 0.0009765625 -dQ Pytorch mean diff: 4.565715789794922e-05 -dK Pytorch mean diff: 3.063678741455078e-05 -dV Pytorch mean diff: 2.5272369384765625e-05 -.Output max diff: 0.00390625 -Output mean diff: 9.632110595703125e-05 -Pytorch max diff: 0.009765625 -Pytorch mean diff: 0.000270843505859375 -==== 256 512 False 128 False torch.bfloat16 -dQ max diff: 0.00390625 -dK max diff: 0.0078125 -dV max diff: 0.00390625 -dQ mean diff: 0.000110626220703125 -dK mean diff: 7.200241088867188e-05 -dV mean diff: 7.343292236328125e-05 -dQ Pytorch max diff: 0.01171875 -dK Pytorch max diff: 0.01171875 -dV Pytorch max diff: 0.0078125 -dQ Pytorch mean diff: 0.000331878662109375 -dK Pytorch mean diff: 0.000213623046875 -dV Pytorch mean diff: 0.00017452239990234375 -.Output max diff: 0.00048828125 -Output mean diff: 3.236532211303711e-05 -Pytorch max diff: 0.00146484375 -Pytorch mean diff: 5.4717063903808594e-05 -==== 256 512 False 128 True torch.float16 -dQ max diff: 0.0009765625 -dK max diff: 0.0009765625 -dV max diff: 0.0009765625 -dQ mean diff: 2.956390380859375e-05 -dK mean diff: 2.086162567138672e-05 -dV mean diff: 1.8596649169921875e-05 -dQ Pytorch max diff: 0.001953125 -dK Pytorch max diff: 0.00390625 -dV Pytorch max diff: 0.001220703125 -dQ Pytorch mean diff: 6.276369094848633e-05 -dK Pytorch mean diff: 3.8504600524902344e-05 -dV Pytorch mean diff: 3.212690353393555e-05 -.Output max diff: 0.00390625 -Output mean diff: 0.00013446807861328125 -Pytorch max diff: 0.01171875 -Pytorch mean diff: 0.000377655029296875 -==== 256 512 False 128 True torch.bfloat16 -dQ max diff: 0.0078125 -dK max diff: 0.0078125 -dV max diff: 0.00390625 -dQ mean diff: 0.00015735626220703125 -dK mean diff: 9.298324584960938e-05 -dV mean diff: 9.632110595703125e-05 -dQ Pytorch max diff: 0.015625 -dK Pytorch max diff: 0.015625 -dV Pytorch max diff: 0.01171875 -dQ Pytorch mean diff: 0.000453948974609375 -dK Pytorch mean diff: 0.00026702880859375 -dV Pytorch mean diff: 0.00022125244140625 -.Output max diff: 0.001953125 -Output mean diff: 1.7821788787841797e-05 -Pytorch max diff: 0.0029296875 -Pytorch mean diff: 2.8133392333984375e-05 -==== 512 256 True 64 False torch.float16 -dQ max diff: 0.001953125 -dK max diff: 0.001953125 -dV max diff: 0.0009765625 -dQ mean diff: 7.450580596923828e-06 -dK mean diff: 1.3172626495361328e-05 -dV mean diff: 2.384185791015625e-07 -dQ Pytorch max diff: 0.00390625 -dK Pytorch max diff: 0.001953125 -dV Pytorch max diff: 0.00390625 -dQ Pytorch mean diff: 3.224611282348633e-05 -dK Pytorch mean diff: 5.4717063903808594e-05 -dV Pytorch mean diff: 4.7147274017333984e-05 -.Output max diff: 0.015625 -Output mean diff: 8.821487426757812e-05 -Pytorch max diff: 0.015625 -Pytorch mean diff: 0.000213623046875 -==== 512 256 True 64 False torch.bfloat16 -dQ max diff: 0.015625 -dK max diff: 0.015625 -dV max diff: 0.015625 -dQ mean diff: 0.00010776519775390625 -dK mean diff: 0.00017452239990234375 -dV mean diff: 0.00018310546875 -dQ Pytorch max diff: 0.0234375 -dK Pytorch max diff: 0.015625 -dV Pytorch max diff: 0.03125 -dQ Pytorch mean diff: 0.00024318695068359375 -dK Pytorch mean diff: 0.0004119873046875 -dV Pytorch mean diff: 0.0003509521484375 -.Output max diff: 0.001953125 -Output mean diff: 1.817941665649414e-05 -Pytorch max diff: 0.0029296875 -Pytorch mean diff: 2.8789043426513672e-05 -==== 512 256 True 64 True torch.float16 -dQ max diff: 0.001953125 -dK max diff: 0.001953125 -dV max diff: 0.001953125 -dQ mean diff: 7.62939453125e-06 -dK mean diff: 1.3887882232666016e-05 -dV mean diff: 2.384185791015625e-07 -dQ Pytorch max diff: 0.00390625 -dK Pytorch max diff: 0.0029296875 -dV Pytorch max diff: 0.00390625 -dQ Pytorch mean diff: 3.3020973205566406e-05 -dK Pytorch mean diff: 5.7756900787353516e-05 -dV Pytorch mean diff: 4.9948692321777344e-05 -.Output max diff: 0.015625 -Output mean diff: 9.012222290039062e-05 -Pytorch max diff: 0.015625 -Pytorch mean diff: 0.000217437744140625 -==== 512 256 True 64 True torch.bfloat16 -dQ max diff: 0.015625 -dK max diff: 0.015625 -dV max diff: 0.03125 -dQ mean diff: 0.00011014938354492188 -dK mean diff: 0.00018405914306640625 -dV mean diff: 0.0001926422119140625 -dQ Pytorch max diff: 0.0234375 -dK Pytorch max diff: 0.015625 -dV Pytorch max diff: 0.03125 -dQ Pytorch mean diff: 0.000247955322265625 -dK Pytorch mean diff: 0.0004329681396484375 -dV Pytorch mean diff: 0.0003681182861328125 -.Output max diff: 0.001953125 -Output mean diff: 2.187490463256836e-05 -Pytorch max diff: 0.001953125 -Pytorch mean diff: 3.463029861450195e-05 -==== 512 256 True 128 False torch.float16 -dQ max diff: 0.00390625 -dK max diff: 0.001953125 -dV max diff: 0.00390625 -dQ mean diff: 1.8894672393798828e-05 -dK mean diff: 3.814697265625e-05 -dV mean diff: 3.421306610107422e-05 -dQ Pytorch max diff: 0.00390625 -dK Pytorch max diff: 0.00390625 -dV Pytorch max diff: 0.00390625 -dQ Pytorch mean diff: 3.8504600524902344e-05 -dK Pytorch mean diff: 6.74128532409668e-05 -dV Pytorch mean diff: 5.7637691497802734e-05 -.Output max diff: 0.015625 -Output mean diff: 8.821487426757812e-05 -Pytorch max diff: 0.015625 -Pytorch mean diff: 0.0002384185791015625 -==== 512 256 True 128 False torch.bfloat16 -dQ max diff: 0.03125 -dK max diff: 0.015625 -dV max diff: 0.03125 -dQ mean diff: 0.0001087188720703125 -dK mean diff: 0.0001773834228515625 -dV mean diff: 0.00018596649169921875 -dQ Pytorch max diff: 0.03125 -dK Pytorch max diff: 0.0234375 -dV Pytorch max diff: 0.03125 -dQ Pytorch mean diff: 0.000278472900390625 -dK Pytorch mean diff: 0.00046539306640625 -dV Pytorch mean diff: 0.000392913818359375 -.Output max diff: 0.001953125 -Output mean diff: 2.2351741790771484e-05 -Pytorch max diff: 0.001953125 -Pytorch mean diff: 3.546476364135742e-05 -==== 512 256 True 128 True torch.float16 -dQ max diff: 0.00390625 -dK max diff: 0.001953125 -dV max diff: 0.00390625 -dQ mean diff: 1.9371509552001953e-05 -dK mean diff: 4.029273986816406e-05 -dV mean diff: 3.612041473388672e-05 -dQ Pytorch max diff: 0.00390625 -dK Pytorch max diff: 0.0029296875 -dV Pytorch max diff: 0.00390625 -dQ Pytorch mean diff: 3.933906555175781e-05 -dK Pytorch mean diff: 7.12275505065918e-05 -dV Pytorch mean diff: 6.091594696044922e-05 -.Output max diff: 0.015625 -Output mean diff: 9.012222290039062e-05 -Pytorch max diff: 0.015625 -Pytorch mean diff: 0.00024318695068359375 -==== 512 256 True 128 True torch.bfloat16 -dQ max diff: 0.03125 -dK max diff: 0.015625 -dV max diff: 0.03125 -dQ mean diff: 0.00011110305786132812 -dK mean diff: 0.000186920166015625 -dV mean diff: 0.000194549560546875 -dQ Pytorch max diff: 0.03125 -dK Pytorch max diff: 0.0234375 -dV Pytorch max diff: 0.03125 -dQ Pytorch mean diff: 0.000286102294921875 -dK Pytorch mean diff: 0.000492095947265625 -dV Pytorch mean diff: 0.0004138946533203125 -""" \ No newline at end of file From c00f8064861936463e5d01648532e10fc7482b9d Mon Sep 17 00:00:00 2001 From: jerryao Date: Fri, 22 Aug 2025 10:23:53 +0800 Subject: [PATCH 16/29] forbidding Learnable sink and ALiBi to party together. --- benchmarks/benchmark_flash_attention.py | 32 ++----------------- csrc/flash_attn/flash_api.cpp | 3 ++ .../src/flash_bwd_launch_template.h | 2 +- .../src/flash_fwd_launch_template.h | 4 +-- tests/test_flash_attn.py | 2 +- 5 files changed, 10 insertions(+), 33 deletions(-) diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 121bf63f27f..8caf0084776 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -13,9 +13,6 @@ from flash_attn import flash_attn_qkvpacked_func from flash_attn import flash_attn_func -from flash_attn import flash_attn_sink_func -from flash_attn_with_sink import flash_attn_with_sink_func -from flash_attn_with_sink_fused import flash_attn_with_sink_fused_func try: from triton.ops.flash_attention import attention as attention_triton @@ -81,7 +78,7 @@ def time_fwd_bwd(func, *args, **kwargs): dim = 2048 dropout_p = 0.0 -methods = (["Flash2", "Flash2UnPacked", "Pytorch", "Flash2Sink", "Flash2SinkFused"] +methods = (["Flash2", "Pytorch", "Flash2Sink"] + (["Triton"] if attention_triton is not None else []) + (["xformers.c"] if xops is not None else []) + (["xformers.f"] if xops is not None else [])) @@ -105,13 +102,6 @@ def time_fwd_bwd(func, *args, **kwargs): time_f[config, "Flash2"] = f time_b[config, "Flash2"] = b - q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(3)] - f, b = time_fwd_bwd( - flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=False - ) - time_f[config, "Flash2UnPacked"] = f - time_b[config, "Flash2UnPacked"] = b try: qkv = qkv.detach().requires_grad_(True) @@ -133,28 +123,12 @@ def time_fwd_bwd(func, *args, **kwargs): sink = torch.randn((nheads,), dtype=torch.float32, device=device, requires_grad=True) f, b = time_fwd_bwd( - flash_attn_with_sink_fused_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False + flash_attn_func, q, k, v, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, learnable_sink=sink, repeats=repeats, verbose=False ) - time_f[config, "Flash2SinkFused"] = f - time_b[config, "Flash2SinkFused"] = b - - try: - scaling = nheads**-0.5 - num_key_value_heads = nheads # // 8 - q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) - k, v = [torch.randn(batch_size, seqlen, num_key_value_heads, headdim, device=device, dtype=dtype, - requires_grad=True) for _ in range(2)] - sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True) - - f, b = time_fwd_bwd( - flash_attn_with_sink_func, q, k, v, sink, softmax_scale=scaling, dropout_p=dropout_p, causal=causal, repeats=repeats, verbose=False - ) - except: # Skip if OOM - f, b = float('nan'), float('nan') time_f[config, "Flash2Sink"] = f time_b[config, "Flash2Sink"] = b + if attention_triton is not None: q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index d163997ae1c..71141aea3b7 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -353,6 +353,9 @@ void set_params_sink(Flash_fwd_params ¶ms, const std::optional 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 49ec54a5a03..3150637a1e6 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -77,7 +77,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -120,7 +120,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index bd793343c2f..60a852b712d 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -2650,7 +2650,7 @@ def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype) k_ref = k.transpose(1, 2) v_ref = v.transpose(1, 2) sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) - out = flash_attn_func(q, k, v, sink, 0.0, softmax_scale=d**-0.5, causal=causal, window_size=window_size) + out = flash_attn_func(q, k, v, 0.0, softmax_scale=d**-0.5, causal=causal, window_size=window_size, learnable_sink=sink) attention_mask = get_attention_mask(seqlen_q, seqlen_k, causal, device, window_size) if attention_mask is not None: From f5044c8020205843159ffd0461820d80ba769688 Mon Sep 17 00:00:00 2001 From: jerryao Date: Sat, 23 Aug 2025 15:54:50 +0800 Subject: [PATCH 17/29] learnable_sink optonal. --- flash_attn/flash_attn_interface.py | 36 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 3d1dbfa5453..63562ecfb7f 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -1050,13 +1050,13 @@ def backward(ctx, dout, *args): def flash_attn_qkvpacked_func( qkv, - learnable_sink=None, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # <=0.0 means deactivate alibi_slopes=None, + learnable_sink=None, deterministic=False, return_attn_probs=False, ): @@ -1072,8 +1072,6 @@ def flash_attn_qkvpacked_func( Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim) - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1082,6 +1080,8 @@ def flash_attn_qkvpacked_func( softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for @@ -1114,13 +1114,13 @@ def flash_attn_qkvpacked_func( def flash_attn_kvpacked_func( q, kv, - learnable_sink, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, + learnable_sink=None, deterministic=False, return_attn_probs=False, ): @@ -1152,8 +1152,6 @@ def flash_attn_kvpacked_func( Arguments: q: (batch_size, seqlen, nheads, headdim) kv: (batch_size, seqlen, 2, nheads_k, headdim) - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1163,6 +1161,8 @@ def flash_attn_kvpacked_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for @@ -1197,13 +1197,13 @@ def flash_attn_func( q, k, v, - learnable_sink, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, + learnable_sink=None, deterministic=False, return_attn_probs=False, ): @@ -1233,8 +1233,6 @@ def flash_attn_func( q: (batch_size, seqlen, nheads, headdim) k: (batch_size, seqlen, nheads_k, headdim) v: (batch_size, seqlen, nheads_k, headdim) - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1243,6 +1241,8 @@ def flash_attn_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for @@ -1276,7 +1276,6 @@ def flash_attn_func( def flash_attn_varlen_qkvpacked_func( qkv, - learnable_sink, cu_seqlens, max_seqlen, dropout_p=0.0, @@ -1285,6 +1284,7 @@ def flash_attn_varlen_qkvpacked_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, + learnable_sink=None, deterministic=False, return_attn_probs=False, ): @@ -1300,8 +1300,6 @@ def flash_attn_varlen_qkvpacked_func( Arguments: qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv. max_seqlen: int. Maximum sequence length in the batch. @@ -1313,6 +1311,8 @@ def flash_attn_varlen_qkvpacked_func( softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for @@ -1347,7 +1347,6 @@ def flash_attn_varlen_qkvpacked_func( def flash_attn_varlen_kvpacked_func( q, kv, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1358,6 +1357,7 @@ def flash_attn_varlen_kvpacked_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, + learnable_sink=None, deterministic=False, return_attn_probs=False, ): @@ -1389,8 +1389,6 @@ def flash_attn_varlen_kvpacked_func( Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths @@ -1406,6 +1404,8 @@ def flash_attn_varlen_kvpacked_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for @@ -1444,7 +1444,6 @@ def flash_attn_varlen_func( q, k, v, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1455,6 +1454,7 @@ def flash_attn_varlen_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, + learnable_sink=None, deterministic=False, return_attn_probs=False, block_table=None, @@ -1485,8 +1485,6 @@ def flash_attn_varlen_func( q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths @@ -1502,6 +1500,8 @@ def flash_attn_varlen_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for From 793f3f59b4fed6d8b9179e5de59500bc4126b647 Mon Sep 17 00:00:00 2001 From: jerryao Date: Sun, 24 Aug 2025 23:13:08 +0800 Subject: [PATCH 18/29] Fix arg learnable_sink. --- flash_attn/flash_attn_interface.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 63562ecfb7f..3d8beb8d1d7 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -240,7 +240,7 @@ def _flash_attn_varlen_forward_fake( _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward -@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv", "dsink"), device_types="cuda") def _flash_attn_backward( dout: torch.Tensor, q: torch.Tensor, @@ -341,7 +341,7 @@ def _flash_attn_backward_fake( _wrapped_flash_attn_backward = _flash_attn_backward -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv", "dsink"), device_types="cuda") def _flash_attn_varlen_backward( dout: torch.Tensor, q: torch.Tensor, @@ -1022,6 +1022,7 @@ def backward(ctx, dout, *args): q, k, v, + learnable_sink, out, softmax_lse, dq, From 51524daf48a4ce44e4a51c5498eeb224fbe8fdb4 Mon Sep 17 00:00:00 2001 From: jerryao Date: Tue, 26 Aug 2025 16:41:35 +0800 Subject: [PATCH 19/29] Fix bwd with sink and dropout. --- csrc/flash_attn/src/flash_bwd_kernel.h | 4 +- tests/test_flash_attn.py | 282 ++++++++++++++++++------- 2 files changed, 209 insertions(+), 77 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 18bc6164267..1b77f979e90 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -589,7 +589,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in float dsink_val_cols = 0.f; #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { - if constexpr (Has_sink) { dsink_val_cols += dS(mi, ni) * scores(mi, ni); } + if constexpr (Has_sink) { dsink_val_cols += pointwise_mult(scores(mi, ni), dS(mi, ni), 0.f); } float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } dS(mi, ni) = scaled_ds; @@ -801,7 +801,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in dsink_val = Allreduce<32>::run(dsink_val, sum_op); if (tidx % 32 == 0 && params.dsink_ptr != nullptr) { float* dsink_ptr = reinterpret_cast(params.dsink_ptr); - float val = -dsink_val * exp2f(sink_val * float(M_LOG2E)); + float val = -dsink_val * exp2f(sink_val * float(M_LOG2E)) * params.rp_dropout; atomicAdd(dsink_ptr + bidh, val); } } diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 60a852b712d..6ce288e1859 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -226,6 +226,7 @@ def attention_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + learnable_sink=None, softcap=0.0, upcast=True, reorder_ops=False, @@ -284,7 +285,16 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias - attention = torch.softmax(scores, dim=-1).to(v.dtype) + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max) + attention = (unnormalized_scores / normalizer).to(v.dtype) # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) @@ -315,6 +325,7 @@ def attention_kvpacked_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + learnable_sink=None, softcap=0.0, upcast=True, reorder_ops=False, @@ -332,6 +343,7 @@ def attention_kvpacked_ref( upcast=upcast, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, reorder_ops=reorder_ops, key_leftpad=key_leftpad, @@ -346,6 +358,7 @@ def attention_qkvpacked_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + learnable_sink=None, softcap=0.0, upcast=True, reorder_ops=False, @@ -362,6 +375,7 @@ def attention_qkvpacked_ref( upcast=upcast, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, reorder_ops=reorder_ops, ) @@ -474,6 +488,7 @@ def normalize_flash_attn_S( is_dropout=False, causal=False, window_size=(-1, -1), # -1 means infinite window size + learnable_sink=None, ): """ Arguments: @@ -508,11 +523,16 @@ def normalize_flash_attn_S( block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) - lse = torch.logsumexp(lse_block, dim=-1) + if learnable_sink is None: + lse = torch.logsumexp(lse_block, dim=-1) + scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) + else: + sinks = learnable_sink.reshape(1, -1, 1, 1).expand(scores.shape[0], -1, scores.shape[2], -1) + lse = torch.logsumexp(torch.cat([lse_block, sinks], dim=-1), dim=-1) + scores_max_block = torch.stack([torch.amax(torch.cat([s, sinks], dim=-1), dim=-1) for s in scores_block], dim=-1) # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. lse[lse == float("-inf")] = float("inf") - scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) attn_norm = torch.cat( @@ -657,9 +677,12 @@ def get_attention_mask( # @pytest.mark.parametrize("seqlen", [512]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) -def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype, has_learnable_sink): if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set seed torch.random.manual_seed(0) @@ -674,12 +697,17 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal) else: alibi_slopes, attn_bias = None, None + if has_learnable_sink: + learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None out, lse, S_dmask = flash_attn_qkvpacked_func( qkv, dropout_p, causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, + learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, ) @@ -708,6 +736,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, None, None, causal=causal, window_size=window_size @@ -717,7 +746,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ dropout_mask = None out_ref, attn_ref = attention_qkvpacked_ref( - qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size + qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size, learnable_sink=learnable_sink ) out_pt, attn_pt = attention_qkvpacked_ref( qkv, @@ -727,6 +756,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, upcast=False, reorder_ops=True, ) @@ -761,9 +791,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64]) # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:]) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - (dqkv,) = torch.autograd.grad(out, qkv, g) - (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) - (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + (dqkv,) = torch.autograd.grad(out, qkv, g, retain_graph=has_learnable_sink) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g, retain_graph=has_learnable_sink) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") @@ -772,6 +802,14 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -785,6 +823,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + if has_learnable_sink: + tol_multiplier = 7 if dtype == torch.float16 and dropout_p > 0.0 else 3 + assert (dsink - dsink_ref).abs().max().item() <= tol_multiplier * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -804,11 +845,14 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ # @pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_varlen_qkvpacked( - seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype + seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype, has_learnable_sink ): if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set seed torch.random.manual_seed(0) @@ -818,6 +862,10 @@ def test_flash_attn_varlen_qkvpacked( qkv = torch.randn( batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True ) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') @@ -841,6 +889,7 @@ def test_flash_attn_varlen_qkvpacked( causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, + learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, ) @@ -870,6 +919,7 @@ def test_flash_attn_varlen_qkvpacked( dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size @@ -886,6 +936,7 @@ def test_flash_attn_varlen_qkvpacked( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_qkvpacked_ref( qkv, @@ -895,6 +946,7 @@ def test_flash_attn_varlen_qkvpacked( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, upcast=False, reorder_ops=True, ) @@ -908,10 +960,10 @@ def test_flash_attn_varlen_qkvpacked( g = torch.randn_like(out) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) + (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g, retain_graph=has_learnable_sink) dqkv = dqkv_pad_fn(dqkv_unpad) - (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) - (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g, retain_graph=has_learnable_sink) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") @@ -920,6 +972,14 @@ def test_flash_attn_varlen_qkvpacked( print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + if has_learnable_sink: + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -933,6 +993,9 @@ def test_flash_attn_varlen_qkvpacked( if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + if has_learnable_sink: + tol_multiplier = 7 if dtype == torch.float16 and dropout_p > 0.0 else 3 + assert (dsink - dsink_ref).abs().max().item() <= tol_multiplier * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("kvpacked", [True, False]) @@ -974,8 +1037,9 @@ def test_flash_attn_varlen_qkvpacked( @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap, has_learnable_sink ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -984,6 +1048,8 @@ def test_flash_attn_output( pytest.skip() # Reference implementation OOM if softcap > 0.0 and dropout_p > 0.0: pytest.skip("Softcap and dropout not supported together") + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set seed torch.random.manual_seed(0) @@ -1012,7 +1078,10 @@ def test_flash_attn_output( attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) else: alibi_slopes, attn_bias = None, None - + if has_learnable_sink: + learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None if kvpacked: out, lse, S_dmask = flash_attn_kvpacked_func( q, @@ -1022,6 +1091,7 @@ def test_flash_attn_output( window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, + learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, ) @@ -1035,6 +1105,7 @@ def test_flash_attn_output( window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, + learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, ) @@ -1069,6 +1140,7 @@ def test_flash_attn_output( dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, None, None, causal=causal, window_size=window_size @@ -1088,6 +1160,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, ) out_pt, attn_pt = attention_kvpacked_ref( @@ -1100,6 +1173,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -1116,6 +1190,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, ) out_pt, attn_pt = attention_ref( @@ -1129,6 +1204,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -1149,34 +1225,34 @@ def test_flash_attn_output( ( dq, dkv, - ) = torch.autograd.grad(out, (q, kv), g) + ) = torch.autograd.grad(out, (q, kv), g, retain_graph=has_learnable_sink) dk, dv = dkv.unbind(2) ( dq_ref, dkv_ref, - ) = torch.autograd.grad(out_ref, (q, kv), g) + ) = torch.autograd.grad(out_ref, (q, kv), g, retain_graph=has_learnable_sink) dk_ref, dv_ref = dkv_ref.unbind(2) ( dq_pt, dkv_pt, - ) = torch.autograd.grad(out_pt, (q, kv), g) + ) = torch.autograd.grad(out_pt, (q, kv), g, retain_graph=has_learnable_sink) dk_pt, dv_pt = dkv_pt.unbind(2) else: ( dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1189,6 +1265,14 @@ def test_flash_attn_output( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + dsink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) + dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1204,6 +1288,8 @@ def test_flash_attn_output( assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("kvpacked", [True, False]) @@ -1220,9 +1306,9 @@ def test_flash_attn_output( # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) -@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [64]) +@pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1243,8 +1329,9 @@ def test_flash_attn_output( @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("softcap", [0.0, 50.0]) # @pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap, has_learnable_sink ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -1253,6 +1340,8 @@ def test_flash_attn_varlen_output( pytest.skip() # Reference implementation OOM if softcap > 0.0 and dropout_p > 0.0: pytest.skip("Softcap and dropout not supported together") + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") device = "cuda" # set seed torch.random.manual_seed(0) @@ -1277,7 +1366,10 @@ def test_flash_attn_varlen_output( v = torch.randn( batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True ) - + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') @@ -1313,6 +1405,7 @@ def test_flash_attn_varlen_output( dropout_p, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, @@ -1345,6 +1438,7 @@ def test_flash_attn_varlen_output( dropout_p, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, @@ -1433,6 +1527,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, ) out_pt, attn_pt = attention_ref( @@ -1446,6 +1541,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -1465,36 +1561,36 @@ def test_flash_attn_varlen_output( ( dq_unpad, dkv_unpad, - ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g) + ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g, retain_graph=has_learnable_sink) dk, dv = dkv_pad_fn(dkv_unpad).unbind(2) ( dq_ref, dkv_ref, - ) = torch.autograd.grad(out_ref, (q, kv), g) + ) = torch.autograd.grad(out_ref, (q, kv), g, retain_graph=has_learnable_sink) dk_ref, dv_ref = dkv_ref.unbind(2) ( dq_pt, dkv_pt, - ) = torch.autograd.grad(out_pt, (q, kv), g) + ) = torch.autograd.grad(out_pt, (q, kv), g, retain_graph=has_learnable_sink) dk_pt, dv_pt = dkv_pt.unbind(2) else: ( dq_unpad, dk_unpad, dv_unpad, - ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=has_learnable_sink) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) dq = dq_pad_fn(dq_unpad) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") @@ -1508,6 +1604,14 @@ def test_flash_attn_varlen_output( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + dsink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) + dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1552,8 +1656,8 @@ def test_flash_attn_varlen_output( (1023, 1024), ], ) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): +@pytest.mark.parametrize("has_learnable_sink", [True, False]) +def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_learnable_sink): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1571,9 +1675,13 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size) + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None + out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, learnable_sink=learnable_sink) out_ref, attn_ref = attention_ref( - q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size + q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size, learnable_sink=learnable_sink ) out_pt, attn_pt = attention_ref( q, @@ -1586,6 +1694,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): None, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, upcast=False, reorder_ops=True, ) @@ -1601,17 +1710,17 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1624,6 +1733,14 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + dsink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) + dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1632,6 +1749,8 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1664,8 +1783,9 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) +@pytest.mark.parametrize("has_learnable_sink", [True, False]) def test_flash_attn_varlen_causal( - seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype + seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, has_learnable_sink, dtype ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -1695,6 +1815,10 @@ def test_flash_attn_varlen_causal( k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype ) + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") ( @@ -1723,6 +1847,7 @@ def test_flash_attn_varlen_causal( 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, block_table=block_table, ) out = output_pad_fn(out_unpad) @@ -1737,6 +1862,7 @@ def test_flash_attn_varlen_causal( None, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_ref( q, @@ -1749,6 +1875,7 @@ def test_flash_attn_varlen_causal( None, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, upcast=False, reorder_ops=True, ) @@ -1766,7 +1893,7 @@ def test_flash_attn_varlen_causal( dq_unpad, dk_unpad, dv_unpad, - ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=has_learnable_sink) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) @@ -1774,12 +1901,12 @@ def test_flash_attn_varlen_causal( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1792,6 +1919,14 @@ def test_flash_attn_varlen_causal( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + dsink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) + dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1801,6 +1936,8 @@ def test_flash_attn_varlen_causal( assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1836,8 +1973,9 @@ def test_flash_attn_varlen_causal( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_splitkv( - seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype + seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype, has_learnable_sink ): if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q @@ -1855,6 +1993,10 @@ def test_flash_attn_splitkv( attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) else: alibi_slopes, attn_bias = None, None + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None out, lse, _ = flash_attn_func( q, k, @@ -1863,11 +2005,12 @@ def test_flash_attn_splitkv( causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, + learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, ) out_ref, attn_ref = attention_ref( - q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size + q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size, learnable_sink=learnable_sink ) out_pt, attn_pt = attention_ref( q, @@ -1880,6 +2023,7 @@ def test_flash_attn_splitkv( None, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, upcast=False, reorder_ops=True, ) @@ -1895,17 +2039,17 @@ def test_flash_attn_splitkv( dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_ref, dk_ref, dv_ref, - ) = torch.autograd.grad(out_ref, (q, k, v), g) + ) = torch.autograd.grad(out_ref, (q, k, v), g, retain_graph=has_learnable_sink) ( dq_pt, dk_pt, dv_pt, - ) = torch.autograd.grad(out_pt, (q, k, v), g) + ) = torch.autograd.grad(out_pt, (q, k, v), g, retain_graph=has_learnable_sink) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") @@ -1918,6 +2062,14 @@ def test_flash_attn_splitkv( print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if has_learnable_sink: + dsink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) + dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -1927,6 +2079,8 @@ def test_flash_attn_splitkv( assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= mult * (dsink_pt - dsink_ref).abs().max().item() + 2e-4 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -2598,36 +2752,14 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) - -@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) -@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [64, 128]) -@pytest.mark.parametrize("swap_sq_sk", [False, True]) -# @pytest.mark.parametrize("swap_sq_sk", [True]) @pytest.mark.parametrize( - "seqlen_q,seqlen_k", + "seqlen_q,seqlen_k,swap_sq_sk,d,local,sink,dtype", [ - (1, 239), # disable seqlenq_ngroups_swapped - (3, 799), - (127, 512), - (127, 513), - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (1023, 1024), + (1, 239, True, 64, True, True, torch.bfloat16), + (127, 513, True, 128, True, True, torch.float16), ], ) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): +def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, sink, dtype): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -2639,9 +2771,9 @@ def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype) causal = True # set seed torch.random.manual_seed(0) - batch_size = 1 - nheads = 64 - num_key_value_groups = 8 + batch_size = 8 + nheads = 9 + num_key_value_groups = 1 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads // num_key_value_groups, d, device=device, dtype=dtype, requires_grad=True) From a317bae024275ede8762c1dff08bbcf9b39f550f Mon Sep 17 00:00:00 2001 From: Shaohong Fu Date: Thu, 4 Sep 2025 15:19:50 +0800 Subject: [PATCH 20/29] Debug for softmax LSE calculation. --- csrc/flash_attn/src/flash_fwd_kernel.h | 44 +++++++++++--------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 89200f3381e..a239a46d99d 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1196,36 +1196,30 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += __expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); + + if constexpr(Has_sink){ + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if(params.unpadded_lse){ + // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). + const int head_idx = lse_offset / (params.b * params.seqlen_q); + const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + }else{ + // LSE is written as (b, h, seqlen_q). + const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; + const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + } + lse_sum += sink_val_exp; + } + } + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : __logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if constexpr (Has_sink) { - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { - const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; - if (params.unpadded_lse) { - // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). - if (lse_offset < lse_size) { - const int head_idx = lse_offset / (params.b * params.seqlen_q); - const float lse_logsum_sink = __expf(lse_logsum); - const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx]); - lse_logsum = __logf(lse_logsum_sink + sink_val_exp);; - } - } else { - // LSE is written as (b, h, seqlen_q). - const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; - const float lse_logsum_sink = __expf(lse_logsum); - const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx]); - lse_logsum = __logf(lse_logsum_sink + sink_val_exp); - } - } - } - } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { if (params.unpadded_lse) { const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; From 594365e8e5dff41e2c8140887691f506c843f5ca Mon Sep 17 00:00:00 2001 From: Shaohong Fu Date: Fri, 5 Sep 2025 09:49:08 +0800 Subject: [PATCH 21/29] Debug for sink calculation. --- csrc/flash_attn/src/flash_fwd_kernel.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index a239a46d99d..a71e47828b7 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1198,20 +1198,21 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { lse_sum = Allreduce::run(lse_sum, sum_op); if constexpr(Has_sink){ - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int row = tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; if (row < params.num_splits && col < kBlockM) { - const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + const index_t lse_offset = row_offset_lse + col; if(params.unpadded_lse){ // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). const int head_idx = lse_offset / (params.b * params.seqlen_q); const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + lse_sum += sink_val_exp; }else{ // LSE is written as (b, h, seqlen_q). const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + lse_sum += sink_val_exp; } - lse_sum += sink_val_exp; } } From 0a4ebacefde397a7605acda677af8ead3a8cc78d Mon Sep 17 00:00:00 2001 From: jerryao Date: Wed, 10 Sep 2025 19:24:01 +0800 Subject: [PATCH 22/29] Fix some bugs. --- csrc/flash_attn/src/flash_bwd_kernel.h | 28 ++- csrc/flash_attn/src/flash_fwd_kernel.h | 14 +- tests/test_flash_attn.py | 239 +++++++++++-------------- 3 files changed, 130 insertions(+), 151 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 1b77f979e90..a2ee466e535 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -86,7 +86,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Shared memory. extern __shared__ char smem_[]; - float dsink_val = 0.f; + double dsink_val = 0.0L; + + const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; // The thread index. const int tidx = threadIdx.x; @@ -581,12 +583,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); - auto pointwise_mult = [](float p, float dp, float d) { + auto pointwise_mult = [](double p, double dp, double d) { return p * (!Is_dropout || p >= 0 ? dp - d : d); }; #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { - float dsink_val_cols = 0.f; + double dsink_val_cols = 0.0L; #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { if constexpr (Has_sink) { dsink_val_cols += pointwise_mult(scores(mi, ni), dS(mi, ni), 0.f); } @@ -594,7 +596,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } dS(mi, ni) = scaled_ds; } - if constexpr (Has_sink) { dsink_val += dsink_val_cols / expf(lse(mi)); } + if constexpr (Has_sink) { dsink_val += dsink_val_cols * expf(sink_val - lse(mi)); } } // if (cute::thread0()) { print(dS); } @@ -796,12 +798,22 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); if constexpr (Has_sink) { - const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; - SumOp sum_op; + SumOp sum_op; + + __shared__ double dsink_block_sum; + if (tidx == 0) dsink_block_sum = 0.0; + __syncthreads(); + dsink_val = Allreduce<32>::run(dsink_val, sum_op); - if (tidx % 32 == 0 && params.dsink_ptr != nullptr) { + + if (tidx % 32 == 0) { + atomicAdd(&dsink_block_sum, dsink_val); + } + __syncthreads(); + + if (tidx == 0 && params.dsink_ptr != nullptr) { float* dsink_ptr = reinterpret_cast(params.dsink_ptr); - float val = -dsink_val * exp2f(sink_val * float(M_LOG2E)) * params.rp_dropout; + float val = -dsink_block_sum * params.rp_dropout; atomicAdd(dsink_ptr + bidh, val); } } diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index a71e47828b7..1b02701cbe7 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -997,8 +997,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Epilogue - // We fix the lse for sink in combine_attn_seqk_parallel. - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // We fix the lse for sink in combine_attn_seqk_parallel if Split is true. + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1196,18 +1196,18 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += __expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); - - if constexpr(Has_sink){ + + if constexpr(Has_sink) { const int row = tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; if (row < params.num_splits && col < kBlockM) { const index_t lse_offset = row_offset_lse + col; - if(params.unpadded_lse){ + if (params.unpadded_lse) { // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). const int head_idx = lse_offset / (params.b * params.seqlen_q); const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); lse_sum += sink_val_exp; - }else{ + } else { // LSE is written as (b, h, seqlen_q). const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); @@ -1215,7 +1215,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } } } - + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : __logf(lse_sum) + lse_max; diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 6ce288e1859..b7407295fa8 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -288,13 +288,11 @@ def attention_ref( if learnable_sink is None: attention = torch.softmax(scores, dim=-1).to(v.dtype) else: - scores_fp32 = scores.to(torch.float32) - logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) - learnable_sink = rearrange(learnable_sink, "h -> h 1 1") - logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) - unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max) - attention = (unnormalized_scores / normalizer).to(v.dtype) + sinks = repeat(learnable_sink, 'h -> b h n 1', b=scores.shape[0], n=scores.shape[2]) + scores = torch.cat([scores, sinks], dim=-1) + attention = torch.softmax(scores, dim=-1).to(v.dtype) + attention = attention[..., :-1] + # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) @@ -527,7 +525,7 @@ def normalize_flash_attn_S( lse = torch.logsumexp(lse_block, dim=-1) scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) else: - sinks = learnable_sink.reshape(1, -1, 1, 1).expand(scores.shape[0], -1, scores.shape[2], -1) + sinks = repeat(learnable_sink, 'h -> b h n 1', b=scores.shape[0], n=scores.shape[2]) lse = torch.logsumexp(torch.cat([lse_block, sinks], dim=-1), dim=-1) scores_max_block = torch.stack([torch.amax(torch.cat([s, sinks], dim=-1), dim=-1) for s in scores_block], dim=-1) # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf @@ -824,8 +822,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() if has_learnable_sink: - tol_multiplier = 7 if dtype == torch.float16 and dropout_p > 0.0 else 3 - assert (dsink - dsink_ref).abs().max().item() <= tol_multiplier * (dsink_pt - dsink_ref).abs().max().item() + assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -994,8 +991,7 @@ def test_flash_attn_varlen_qkvpacked( if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() if has_learnable_sink: - tol_multiplier = 7 if dtype == torch.float16 and dropout_p > 0.0 else 3 - assert (dsink - dsink_ref).abs().max().item() <= tol_multiplier * (dsink_pt - dsink_ref).abs().max().item() + assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("kvpacked", [True, False]) @@ -1366,10 +1362,6 @@ def test_flash_attn_varlen_output( v = torch.randn( batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True ) - if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) - else: - learnable_sink = None query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') @@ -1380,7 +1372,10 @@ def test_flash_attn_varlen_output( ) else: alibi_slopes, attn_bias = None, None - + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None if kvpacked: ( q_unpad, @@ -1476,6 +1471,7 @@ def test_flash_attn_varlen_output( dropout_p > 0.0, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, ) dropout_fraction = get_dropout_fraction( dropout_mask, @@ -1499,6 +1495,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, ) out_pt, attn_pt = attention_kvpacked_ref( @@ -1511,6 +1508,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, @@ -1627,6 +1625,8 @@ def test_flash_attn_varlen_output( assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + if has_learnable_sink: + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1656,7 +1656,7 @@ def test_flash_attn_varlen_output( (1023, 1024), ], ) -@pytest.mark.parametrize("has_learnable_sink", [True, False]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_learnable_sink): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -1750,7 +1750,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_ assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 if has_learnable_sink: - assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 + assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -1783,7 +1783,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_ # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) -@pytest.mark.parametrize("has_learnable_sink", [True, False]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_varlen_causal( seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, has_learnable_sink, dtype ): @@ -1977,6 +1977,8 @@ def test_flash_attn_varlen_causal( def test_flash_attn_splitkv( seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype, has_learnable_sink ): + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q device = "cuda" @@ -2063,13 +2065,13 @@ def test_flash_attn_splitkv( print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") if has_learnable_sink: - dsink, = torch.autograd.grad(out, (learnable_sink,), g) - dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) - dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) - print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") - print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") - print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") - print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") + dsink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) + dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") + print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") + print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") + print(f"dSink Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. @@ -2132,6 +2134,7 @@ def test_flash_attn_splitkv( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("has_learnable_sink", [True]) def test_flash_attn_kvcache( seqlen_q, seqlen_k, @@ -2149,6 +2152,7 @@ def test_flash_attn_kvcache( mha_type, num_splits, dtype, + has_learnable_sink ): if seqlen_q > seqlen_k and new_kv: pytest.skip() @@ -2158,6 +2162,9 @@ def test_flash_attn_kvcache( pytest.skip() if has_leftpad and paged_kv_block_size is not None: pytest.skip() + if alibi and has_learnable_sink: + pytest.skip("Alibi and learnable sink not supported together") + device = "cuda" # set seed torch.random.manual_seed(0) @@ -2229,6 +2236,10 @@ def test_flash_attn_kvcache( ) else: alibi_slopes, attn_bias = None, None + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) if rotary_dim > 0: angle = ( @@ -2296,6 +2307,7 @@ def test_flash_attn_kvcache( window_size=window_size, rotary_interleaved=rotary_interleaved, alibi_slopes=alibi_slopes, + learnable_sink=learnable_sink, num_splits=num_splits, ) # out = flash_attn_with_kvcache( @@ -2319,6 +2331,7 @@ def test_flash_attn_kvcache( None, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, key_leftpad=cache_leftpad, ) out_pt, _ = attention_ref( @@ -2332,6 +2345,7 @@ def test_flash_attn_kvcache( None, causal=causal, window_size=window_size, + learnable_sink=learnable_sink, upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, @@ -2424,7 +2438,8 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, ) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.0]) -def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype, has_learnable_sink): device = "cuda" # set seed torch.random.manual_seed(0) @@ -2433,21 +2448,28 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None torch.random.manual_seed(42) - out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) + out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True, learnable_sink=learnable_sink) g = torch.randn_like(out0) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): ( dq0, dk0, dv0, - ) = torch.autograd.grad(out0, (q, k, v), g) + ) = torch.autograd.grad(out0, (q, k, v), g, retain_graph=has_learnable_sink) # Numerical error if we just do any arithmetic on dq dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + if has_learnable_sink: + dlearnable_sink0, = torch.autograd.grad(out0, (learnable_sink,), g) + dsink_atol = 2 * ((dlearnable_sink0 + 0.3 - 0.3) - dlearnable_sink0).abs().max().item() for i in range(250): torch.random.manual_seed(42) - out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) + out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True, learnable_sink=learnable_sink) assert torch.equal(out, out0) assert torch.equal(lse, lse0) @@ -2456,13 +2478,19 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty dq, dk, dv, - ) = torch.autograd.grad(out, (q, k, v), g) + ) = torch.autograd.grad(out, (q, k, v), g, retain_graph=has_learnable_sink) dq_equal = torch.allclose(dq, dq0, atol=dq_atol) if not dq_equal: print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert dq_equal + if has_learnable_sink: + dlearnable_sink, = torch.autograd.grad(out, (learnable_sink,), g) + dsink_equal = torch.allclose(dlearnable_sink, dlearnable_sink0, atol=dsink_atol) + if not dsink_equal: + print(f"Iter {i}, {dsink_atol = }, dSink max diff: {(dlearnable_sink - dlearnable_sink0).abs().max().item()}") + assert dsink_equal @pytest.mark.parametrize("dtype", [torch.float16]) @@ -2472,7 +2500,8 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty # @pytest.mark.parametrize('d', [16]) @pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128]) # @pytest.mark.parametrize('seqlen', [2]) -def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype, has_learnable_sink): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0. """ @@ -2489,18 +2518,24 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) - out = flash_attn_func(q, k, v, causal=causal) + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None + out = flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink) g = torch.randn_like(out) out.backward(g) q_pt = q.detach().clone().requires_grad_(True) k_pt = k.detach().clone().requires_grad_(True) v_pt = v.detach().clone().requires_grad_(True) - out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + learnable_sink_pt = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, learnable_sink=learnable_sink_pt, upcast=False, reorder_ops=True) out_pt.backward(g) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) v_ref = v.detach().clone().requires_grad_(True) - out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + learnable_sink_ref = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal, learnable_sink=learnable_sink_ref) out_ref.backward(g) print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") @@ -2518,6 +2553,12 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): assert (v.grad - v_ref.grad).abs().max().item() <= 5 * ( v_pt.grad - v_ref.grad ).abs().max().item() + 1e-3 + if has_learnable_sink: + print(f"dSink max diff: {(learnable_sink.grad - learnable_sink_ref.grad).abs().max().item()}") + print(f"dSink Pytorch max diff: {(learnable_sink_pt.grad - learnable_sink_ref.grad).abs().max().item()}") + assert (learnable_sink.grad - learnable_sink_ref.grad).abs().max().item() <= 2 * ( + learnable_sink_pt.grad - learnable_sink_ref.grad + ).abs().max().item() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -2528,7 +2569,8 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256]) # @pytest.mark.parametrize('seqlen', [128]) -def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype, has_learnable_sink): """We previously had a bug where we were using the wrong strides of dout, which shows up when dout is not contiguous. """ @@ -2541,20 +2583,26 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True) for _ in range(3) ] - out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...") + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None + out = rearrange(flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink), "b s ... -> s b ...") # So g is not contiguous g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] out.backward(g) q_pt = q.detach().clone().requires_grad_(True) k_pt = k.detach().clone().requires_grad_(True) v_pt = v.detach().clone().requires_grad_(True) - out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + learnable_sink_pt = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, learnable_sink=learnable_sink_pt, upcast=False, reorder_ops=True) out_pt = rearrange(out_pt, "b s ... -> s b ...") out_pt.backward(g) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) v_ref = v.detach().clone().requires_grad_(True) - out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + learnable_sink_ref = learnable_sink.detach().clone().requires_grad_(True) if has_learnable_sink else None + out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal, learnable_sink=learnable_sink_ref) out_ref = rearrange(out_ref, "b s ... -> s b ...") out_ref.backward(g) print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") @@ -2573,6 +2621,12 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): assert (v.grad - v_ref.grad).abs().max().item() <= 2 * ( v_pt.grad - v_ref.grad ).abs().max().item() + if has_learnable_sink: + print(f"dSink max diff: {(learnable_sink.grad - learnable_sink_ref.grad).abs().max().item()}") + print(f"dSink Pytorch max diff: {(learnable_sink_pt.grad - learnable_sink_ref.grad).abs().max().item()}") + assert (learnable_sink.grad - learnable_sink_ref.grad).abs().max().item() <= 2 * ( + learnable_sink_pt.grad - learnable_sink_ref.grad + ).abs().max().item() @pytest.mark.parametrize("dtype", [torch.float16]) @@ -2580,7 +2634,8 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [16, 32, 64]) # @pytest.mark.parametrize('d', [16]) -def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +def test_flash_attn_bwd_varlen_overflow(d, causal, dtype, has_learnable_sink): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0 or varlen. """ @@ -2598,15 +2653,20 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) + if has_learnable_sink: + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + else: + learnable_sink = None - out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal) + out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal, learnable_sink=learnable_sink) g = torch.randn_like(out) out.backward(g) assert not q.grad.isnan().any() assert not k.grad.isnan().any() assert not v.grad.isnan().any() - + if has_learnable_sink: + assert not learnable_sink.grad.isnan().any() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -2751,96 +2811,3 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) - -@pytest.mark.parametrize( - "seqlen_q,seqlen_k,swap_sq_sk,d,local,sink,dtype", - [ - (1, 239, True, 64, True, True, torch.bfloat16), - (127, 513, True, 128, True, True, torch.float16), - ], -) -def test_flash_attn_sink_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, sink, dtype): - if ( - max(seqlen_q, seqlen_k) >= 2048 - and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 - ): - pytest.skip() # Reference implementation OOM - if swap_sq_sk: - seqlen_q, seqlen_k = seqlen_k, seqlen_q - device = "cuda" - causal = True - # set seed - torch.random.manual_seed(0) - batch_size = 8 - nheads = 9 - num_key_value_groups = 1 - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) - k = torch.randn(batch_size, seqlen_k, nheads // num_key_value_groups, d, device=device, dtype=dtype, requires_grad=True) - v = torch.randn(batch_size, seqlen_k, nheads // num_key_value_groups, d, device=device, dtype=dtype, requires_grad=True) - q_ref = q.transpose(1, 2) - k_ref = k.transpose(1, 2) - v_ref = v.transpose(1, 2) - sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) - out = flash_attn_func(q, k, v, 0.0, softmax_scale=d**-0.5, causal=causal, window_size=window_size, learnable_sink=sink) - - attention_mask = get_attention_mask(seqlen_q, seqlen_k, causal, device, window_size) - if attention_mask is not None: - attention_mask = attention_mask.expand(batch_size, nheads, -1, -1).to(dtype) - - out_ref, _ = attention_sink_ref(q_ref.float(), k_ref.float(), v_ref.float(), sink.float(), attention_mask.float(), d**-0.5, 0.0, num_key_value_groups) - out_ref = out_ref.to(dtype) - out_pt, _ = attention_sink_ref(q_ref, k_ref, v_ref, sink.to(dtype), attention_mask, d**-0.5, 0.0, num_key_value_groups) - - print(f"Output max diff: {(out - out_ref).abs().max().detach().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().detach().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().detach().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().detach().item()}") - - g = torch.randn_like(out) - do_o = (g.float() * out.float()).sum(-1) - ( - dq, - dk, - dv, - dsink, - ) = torch.autograd.grad(out, (q, k, v, sink), g) - ( - dq_ref, - dk_ref, - dv_ref, - dsink_ref, - ) = torch.autograd.grad(out_ref, (q, k, v, sink), g) - ( - dq_pt, - dk_pt, - dv_pt, - dsink_pt, - ) = torch.autograd.grad(out_pt, (q, k, v, sink), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().detach().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().detach().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().detach().item()}") - print(f"dS max diff: {(dsink - dsink_ref).abs().max().detach().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().detach().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().detach().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().detach().item()}") - print(f"dS mean diff: {(dsink - dsink_ref).abs().mean().detach().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().detach().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().detach().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().detach().item()}") - print(f"dS Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().detach().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().detach().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().detach().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().detach().item()}") - print(f"dS Pytorch mean diff: {(dsink_pt - dsink_ref).abs().mean().detach().item()}") - print(f"dS Relative error: {torch.mean(torch.abs(dsink - dsink_ref) / (torch.abs(dsink_ref) + 1e-8)).detach()}") - print(f"dS Pytorch relative error: {torch.mean(torch.abs(dsink_pt - dsink_ref) / (torch.abs(dsink_ref) + 1e-8)).detach()}") - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().detach().item() <= 2 * (out_pt - out_ref).abs().max().detach().item() + 1e-5 - - assert (dq - dq_ref).abs().max().detach().item() <= 2 * (dq_pt - dq_ref).abs().max().detach().item() + 1e-5 - assert (dk - dk_ref).abs().max().detach().item() <= 2 * (dk_pt - dk_ref).abs().max().detach().item() + 1e-5 - assert (dv - dv_ref).abs().max().detach().item() <= 2 * (dv_pt - dv_ref).abs().max().detach().item() + 1e-5 - assert (dsink - dsink_ref).abs().max().detach().item() <= 2 * (dsink_pt - dsink_ref).abs().max().detach().item() + 1e-5 From 067cf4746dca60e31532d2f3b5eb7ad9be89d2f7 Mon Sep 17 00:00:00 2001 From: jerryao Date: Thu, 11 Sep 2025 15:33:47 +0800 Subject: [PATCH 23/29] Update tests. --- tests/test_flash_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index b7407295fa8..73d6cdd3fa3 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -288,6 +288,7 @@ def attention_ref( if learnable_sink is None: attention = torch.softmax(scores, dim=-1).to(v.dtype) else: + learnable_sink = learnable_sink.to(q.dtype) sinks = repeat(learnable_sink, 'h -> b h n 1', b=scores.shape[0], n=scores.shape[2]) scores = torch.cat([scores, sinks], dim=-1) attention = torch.softmax(scores, dim=-1).to(v.dtype) @@ -525,6 +526,7 @@ def normalize_flash_attn_S( lse = torch.logsumexp(lse_block, dim=-1) scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) else: + learnable_sink = learnable_sink.to(q.dtype) sinks = repeat(learnable_sink, 'h -> b h n 1', b=scores.shape[0], n=scores.shape[2]) lse = torch.logsumexp(torch.cat([lse_block, sinks], dim=-1), dim=-1) scores_max_block = torch.stack([torch.amax(torch.cat([s, sinks], dim=-1), dim=-1) for s in scores_block], dim=-1) From 408083e08ebab944953588bf8da106b6840c074a Mon Sep 17 00:00:00 2001 From: jerryao Date: Sat, 13 Sep 2025 22:12:01 +0800 Subject: [PATCH 24/29] Fix. --- csrc/flash_attn/src/flash_fwd_kernel.h | 11 +++++------ csrc/flash_attn/src/softmax.h | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 1b02701cbe7..8ee81442a96 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -341,7 +341,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2, params.scale_softmax) : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // Convert acc_s from fp32 to fp16/bf16 @@ -535,7 +535,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } - if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. // Otherwise we might read OOB elements from gK and gV, @@ -917,7 +916,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2, params.scale_softmax) : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } @@ -1191,9 +1190,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { MaxOp max_op; lse_max = Allreduce::run(lse_max, max_op); lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum = __expf(lse_accum(0) - lse_max); + float lse_sum = expf(lse_accum(0) - lse_max); #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += __expf(lse_accum(l) - lse_max); } + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); @@ -1218,7 +1217,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : __logf(lse_sum) + lse_max; + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index b4b0f96d31c..a419a8f72f0 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -135,13 +135,13 @@ struct Softmax { __forceinline__ __device__ Softmax(const float sink_val = -INFINITY) : sink_val(sink_val) {}; template - __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2, float softmax_scale=1.0) { // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); if (Is_first) { if constexpr (Has_sink) { - const float sink_scaled = (sink_val * float(M_LOG2E) / softmax_scale_log2); + const float sink_scaled = sink_val / softmax_scale; #pragma unroll for (int mi = 0; mi < size(row_max); ++mi) { row_max(mi) = sink_scaled; } FLASH_NAMESPACE::template reduce_max(scores, row_max); From f5dd38e9aead908490934558b57a53841308b76f Mon Sep 17 00:00:00 2001 From: jerryao Date: Mon, 15 Sep 2025 14:21:37 +0800 Subject: [PATCH 25/29] Fix bugs. --- csrc/flash_attn/src/flash_fwd_kernel.h | 8 +- .../src/flash_fwd_launch_template.h | 3 +- tests/test_flash_attn.py | 77 ++++--------------- 3 files changed, 20 insertions(+), 68 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 8ee81442a96..9e96e9f487f 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -89,6 +89,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } } + const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { @@ -122,7 +123,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi #pragma unroll for (int m = 0; m < size<1>(tOgO); ++m) { const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = Has_sink ? sink_val : INFINITY; } } return; } @@ -282,7 +283,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_o); - const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; @@ -996,8 +996,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Epilogue - // We fix the lse for sink in combine_attn_seqk_parallel if Split is true. - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1219,7 +1218,6 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { if (params.unpadded_lse) { const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 3150637a1e6..bce29444449 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -120,7 +120,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; + // We fix the lse for sink in combine_attn_seqk_parallel if Split is true. + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 73d6cdd3fa3..49f7ebed489 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -585,53 +585,6 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def attention_sink_ref( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sink: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - num_key_value_groups: int = 8, - **kwargs, -): - key_states = repeat_kv(key, num_key_value_groups) - value_states = repeat_kv(value, num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - sinks = sink.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - combined_logits = torch.cat([attn_weights, sinks], dim=-1) - # This was not in the original implementation and slightly affect results; - # it prevents overflow in BF16/FP16 when training with bsz>1 we clamp max values. - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - - probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) - scores = probs[..., :-1] # we drop the sink here - attn_weights = torch.nn.functional.dropout(scores, p=dropout, training=True) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - def get_attention_mask( seqlen_q, seqlen_k, @@ -698,7 +651,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ else: alibi_slopes, attn_bias = None, None if has_learnable_sink: - learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None out, lse, S_dmask = flash_attn_qkvpacked_func( @@ -862,7 +815,7 @@ def test_flash_attn_varlen_qkvpacked( batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True ) if has_learnable_sink: - learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None @@ -993,7 +946,7 @@ def test_flash_attn_varlen_qkvpacked( if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() if has_learnable_sink: - assert (dsink - dsink_ref).abs().max().item() <= 2 * (dsink_pt - dsink_ref).abs().max().item() + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() @pytest.mark.parametrize("kvpacked", [True, False]) @@ -1077,7 +1030,7 @@ def test_flash_attn_output( else: alibi_slopes, attn_bias = None, None if has_learnable_sink: - learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None if kvpacked: @@ -1375,7 +1328,7 @@ def test_flash_attn_varlen_output( else: alibi_slopes, attn_bias = None, None if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None if kvpacked: @@ -1678,7 +1631,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_ k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, learnable_sink=learnable_sink) @@ -1818,7 +1771,7 @@ def test_flash_attn_varlen_causal( seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype ) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") @@ -1998,7 +1951,7 @@ def test_flash_attn_splitkv( else: alibi_slopes, attn_bias = None, None if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None out, lse, _ = flash_attn_func( @@ -2084,7 +2037,7 @@ def test_flash_attn_splitkv( assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 if has_learnable_sink: - assert (dsink - dsink_ref).abs().max().item() <= mult * (dsink_pt - dsink_ref).abs().max().item() + 2e-4 + assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() + 2e-4 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @@ -2136,7 +2089,7 @@ def test_flash_attn_splitkv( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -@pytest.mark.parametrize("has_learnable_sink", [True]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_kvcache( seqlen_q, seqlen_k, @@ -2239,7 +2192,7 @@ def test_flash_attn_kvcache( else: alibi_slopes, attn_bias = None, None if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) @@ -2451,7 +2404,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None torch.random.manual_seed(42) @@ -2521,7 +2474,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype, has_learnable_sink): k.requires_grad_(True) v.requires_grad_(True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None out = flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink) @@ -2586,7 +2539,7 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype, has_learnable_sink): for _ in range(3) ] if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None out = rearrange(flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink), "b s ... -> s b ...") @@ -2656,7 +2609,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype, has_learnable_sink): k.requires_grad_(True) v.requires_grad_(True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 else: learnable_sink = None From b0ed042411a7854904f9eac5e482d039b0fe61d6 Mon Sep 17 00:00:00 2001 From: jerryao Date: Tue, 16 Sep 2025 14:58:12 +0800 Subject: [PATCH 26/29] clean code. --- csrc/flash_attn/src/flash_bwd_kernel.h | 25 +++++++++++++------------ csrc/flash_attn/src/kernel_traits.h | 2 +- tests/test_flash_attn.py | 6 ++++-- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index a2ee466e535..f373cb141c2 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -86,7 +86,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Shared memory. extern __shared__ char smem_[]; - double dsink_val = 0.0L; + float dsink_val = 0.0f; const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; @@ -453,6 +453,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in clear(acc_dv); clear(acc_dk); + if constexpr (Has_sink) { + float* dsink_block_sum_ptr = reinterpret_cast(smem_ + Kernel_traits::kSmemSize1colblock - sizeof(float)); + if (tidx == 0) *dsink_block_sum_ptr = 0.0f; + } + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); @@ -583,15 +588,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); - auto pointwise_mult = [](double p, double dp, double d) { + auto pointwise_mult = [](float p, float dp, float d) { return p * (!Is_dropout || p >= 0 ? dp - d : d); }; #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { - double dsink_val_cols = 0.0L; + float dsink_val_cols = 0.0f; #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { - if constexpr (Has_sink) { dsink_val_cols += pointwise_mult(scores(mi, ni), dS(mi, ni), 0.f); } + if constexpr (Has_sink) { dsink_val_cols += pointwise_mult(scores(mi, ni), dS(mi, ni), 0.0f); } float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } dS(mi, ni) = scaled_ds; @@ -798,22 +803,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); if constexpr (Has_sink) { - SumOp sum_op; - - __shared__ double dsink_block_sum; - if (tidx == 0) dsink_block_sum = 0.0; - __syncthreads(); - + SumOp sum_op; + float* dsink_block_sum_ptr = reinterpret_cast(smem_ + Kernel_traits::kSmemSize1colblock - sizeof(float)); dsink_val = Allreduce<32>::run(dsink_val, sum_op); if (tidx % 32 == 0) { - atomicAdd(&dsink_block_sum, dsink_val); + atomicAdd(dsink_block_sum_ptr, dsink_val); } __syncthreads(); if (tidx == 0 && params.dsink_ptr != nullptr) { float* dsink_ptr = reinterpret_cast(params.dsink_ptr); - float val = -dsink_block_sum * params.rp_dropout; + float val = -(*dsink_block_sum_ptr) * params.rp_dropout; atomicAdd(dsink_ptr + bidh, val); } } diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index 8c0897488dc..49c824fa987 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -287,7 +287,7 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)) + sizeof(float); // + sizeof(float) for dsink static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 49f7ebed489..2389b1d2590 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1031,6 +1031,8 @@ def test_flash_attn_output( alibi_slopes, attn_bias = None, None if has_learnable_sink: learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 + if softcap > 0: + learnable_sink = learnable_sink * softcap else: learnable_sink = None if kvpacked: @@ -2032,7 +2034,7 @@ def test_flash_attn_splitkv( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - mult = 2 if not alibi else 8 + mult = 2 if not (alibi or (has_learnable_sink and dtype == torch.float16)) else 8 assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 @@ -2333,7 +2335,7 @@ def test_flash_attn_kvcache( )[:, :seqlen_k] assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) assert torch.equal(v_cache_select, v_cache_ref) - mult = 3 if not alibi else 5 + mult = 3 if not (alibi or (has_learnable_sink and dtype == torch.float16)) else 5 assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 From b8f99537fb0249fa34ebfe693d3464834a3d5a8b Mon Sep 17 00:00:00 2001 From: jerryao Date: Thu, 18 Sep 2025 20:13:18 +0800 Subject: [PATCH 27/29] Fix tests. --- tests/test_flash_attn.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 2389b1d2590..3236cb69443 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1032,7 +1032,7 @@ def test_flash_attn_output( if has_learnable_sink: learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 if softcap > 0: - learnable_sink = learnable_sink * softcap + learnable_sink = learnable_sink * softcap * 2 else: learnable_sink = None if kvpacked: @@ -1242,7 +1242,8 @@ def test_flash_attn_output( assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() if has_learnable_sink: - assert (dsink - dsink_ref).abs().max().item() <= 3 * (dsink_pt - dsink_ref).abs().max().item() + mult = 3 if not (dtype == torch.float16 and softcap > 0) else softcap + assert (dsink - dsink_ref).abs().max().item() <= mult * (dsink_pt - dsink_ref).abs().max().item() + 1e-5 @pytest.mark.parametrize("kvpacked", [True, False]) @@ -2422,7 +2423,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() if has_learnable_sink: dlearnable_sink0, = torch.autograd.grad(out0, (learnable_sink,), g) - dsink_atol = 2 * ((dlearnable_sink0 + 0.3 - 0.3) - dlearnable_sink0).abs().max().item() + dsink_atol = 2 * ((dlearnable_sink0 + 0.3 - 0.3) - dlearnable_sink0).abs().max().item() + 2e-4 for i in range(250): torch.random.manual_seed(42) @@ -2476,7 +2477,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype, has_learnable_sink): k.requires_grad_(True) v.requires_grad_(True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) else: learnable_sink = None out = flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink) @@ -2541,7 +2542,7 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype, has_learnable_sink): for _ in range(3) ] if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) else: learnable_sink = None out = rearrange(flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink), "b s ... -> s b ...") @@ -2611,7 +2612,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype, has_learnable_sink): k.requires_grad_(True) v.requires_grad_(True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 + learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) else: learnable_sink = None From 9d099a0bbeb99fe0cfc9ecbff71aff454769a747 Mon Sep 17 00:00:00 2001 From: jerryao Date: Tue, 21 Oct 2025 10:07:23 +0800 Subject: [PATCH 28/29] Add new line. --- csrc/flash_attn/src/flash_bwd_kernel.h | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index f373cb141c2..8eee366c5b2 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -108,7 +108,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); } - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) From 0d26f1e41fed09ee03801403164dd6ea5c26f6ad Mon Sep 17 00:00:00 2001 From: aoxy Date: Wed, 25 Feb 2026 14:50:59 +0800 Subject: [PATCH 29/29] Refactoring: Unifying sink-related parameter processing and optimizing kernel and interfaces --- benchmarks/benchmark_flash_attention.py | 2 +- csrc/flash_attn/flash_api.cpp | 188 ++++++++++------ csrc/flash_attn/src/flash.h | 6 +- csrc/flash_attn/src/flash_bwd_kernel.h | 25 ++- .../src/flash_bwd_launch_template.h | 10 +- csrc/flash_attn/src/flash_fwd_kernel.h | 30 +-- .../src/flash_fwd_launch_template.h | 12 +- csrc/flash_attn/src/softmax.h | 2 +- flash_attn/flash_attn_interface.py | 206 +++++++++--------- setup.py | 2 + tests/test_flash_attn.py | 181 +++++++-------- 11 files changed, 342 insertions(+), 322 deletions(-) diff --git a/benchmarks/benchmark_flash_attention.py b/benchmarks/benchmark_flash_attention.py index 7aa4fdc3a4a..a97668298fb 100644 --- a/benchmarks/benchmark_flash_attention.py +++ b/benchmarks/benchmark_flash_attention.py @@ -110,7 +110,7 @@ def time_fwd_bwd(func, *args, **kwargs): if "Flash2Sink" in methods: qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True) - sink = torch.randn((nheads,), dtype=torch.float32, device=device, requires_grad=True) + sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True) f, b = time_fwd_bwd( flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, learnable_sink=sink, repeats=repeats, verbose=False ) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 22769a0e554..9eb4c73a2b9 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -347,32 +347,10 @@ void set_params_alibi(Flash_fwd_params ¶ms, std::optional &alibi #endif } -void set_params_sink(Flash_fwd_params ¶ms, const std::optional &learnable_sink_, int num_heads) { -#ifdef FLASHATTENTION_DISABLE_SINK - TORCH_CHECK(!learnable_sink_.has_value(), "This flash attention build does not support learnable sink."); - params.learnable_sink_ptr = nullptr; -#else - if (learnable_sink_.has_value()) { - // Make the compiler happy by forbidding Learnable sink and ALiBi to party together — - // mixing them causes a template explosion and very long compile times! - TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); - auto learnable_sink = learnable_sink_.value(); - TORCH_CHECK(learnable_sink.dtype() == torch::kFloat32, "Learnable sink must have dtype fp32"); - CHECK_DEVICE(learnable_sink); - TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); - CHECK_SHAPE(learnable_sink, num_heads); - params.learnable_sink_ptr = learnable_sink.data_ptr(); - } else { - params.learnable_sink_ptr = nullptr; - } -#endif -} - std::vector mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &learnable_sink_, // num_heads std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, @@ -382,7 +360,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { + std::optional gen_, + std::optional learnable_sink_) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -516,7 +495,18 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink_, num_heads); + + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + } if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -539,8 +529,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &learnable_sink_, // num_heads - std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. @@ -557,7 +546,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { + std::optional gen_, + std::optional learnable_sink_) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -758,7 +748,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink_, num_heads); + + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + } if (max_seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -795,13 +796,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &learnable_sink, // num_heads const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dsink_, // num_heads std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, @@ -811,7 +810,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl const float softcap, const bool deterministic, std::optional gen_, - std::optional &rng_state) { + std::optional &rng_state, + std::optional learnable_sink_, + std::optional dsink_) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); @@ -874,7 +875,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); - at::Tensor dq, dk, dv, dsink; + at::Tensor dq, dk, dv; if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); @@ -902,17 +903,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl } else { dv = torch::empty_like(v); } - if (learnable_sink.has_value()) { - if (dsink_.has_value()) { - dsink = dsink_.value(); - TORCH_CHECK(dsink.dtype() == torch::kFloat32, "dsink must have dtype fp32"); - CHECK_DEVICE(dsink); - TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension"); - CHECK_SHAPE(dsink, num_heads); - } else { - dsink = torch::zeros_like(v); - } - } // bool loop = seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity @@ -970,7 +960,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl /*unpadded_lse*/false); params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); - auto launch = &run_mha_bwd; auto gen = at::get_generator_or_default( @@ -991,8 +980,30 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink, num_heads); - params.dsink_ptr = learnable_sink.has_value() ? dsink.data_ptr() : nullptr; + + at::Tensor dsink; + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + if (dsink_.has_value() && dsink_.value().dtype() == torch::kFloat32) { + dsink = dsink_.value(); + CHECK_DEVICE(dsink); CHECK_CONTIGUOUS(dsink); + TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension"); + CHECK_SHAPE(dsink, num_heads); + } else { + dsink = torch::empty_like(learnable_sink); + } + dsink.zero_(); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + params.dsink_ptr = dsink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + params.dsink_ptr = nullptr; + } if (seqlen_q > 0) { launch(params, stream); @@ -1009,22 +1020,28 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); } - return { dq, dk, dv, dsink, softmax_d }; -} + if (learnable_sink_.has_value()) { + if (!dsink_.has_value()) { + dsink = dsink.to(learnable_sink_.value().dtype()); + } else if (dsink_.value().dtype() != torch::kFloat32) { + dsink = dsink.to(dsink_.value().dtype()); + dsink_.value().copy_(dsink); + } + } + return { dq, dk, dv, softmax_d, dsink }; +} std::vector mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - std::optional &learnable_sink, // num_heads const at::Tensor &out, // total_q x num_heads x head_size const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - std::optional &dsink_, // num_heads const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &alibi_slopes_, // num_heads or b x num_heads @@ -1039,7 +1056,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const float softcap, const bool deterministic, std::optional gen_, - std::optional &rng_state) { + std::optional &rng_state, + std::optional learnable_sink_, + std::optional dsink_) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); @@ -1108,7 +1127,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - at::Tensor dq, dk, dv, dsink; + at::Tensor dq, dk, dv; if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); @@ -1136,17 +1155,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } else { dv = torch::empty_like(v); } - if (learnable_sink.has_value()) { - if (dsink_.has_value()) { - dsink = dsink_.value(); - TORCH_CHECK(dsink.dtype() == torch::kFloat32, "dsink must have dtype fp32"); - CHECK_DEVICE(dsink); - TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension"); - CHECK_SHAPE(dsink, num_heads); - } else { - dsink = torch::zeros_like(v); - } - } // bool loop = max_seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity @@ -1236,8 +1244,30 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink, num_heads); - params.dsink_ptr = learnable_sink.has_value() ? dsink.data_ptr() : nullptr; + + at::Tensor dsink; + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + if (dsink_.has_value() && dsink_.value().dtype() == torch::kFloat32) { + dsink = dsink_.value(); + CHECK_DEVICE(dsink); CHECK_CONTIGUOUS(dsink); + TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension"); + CHECK_SHAPE(dsink, num_heads); + } else { + dsink = torch::empty_like(learnable_sink); + } + dsink.zero_(); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + params.dsink_ptr = dsink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + params.dsink_ptr = nullptr; + } if (max_seqlen_q > 0) { launch(params, stream); @@ -1254,7 +1284,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); } - return { dq, dk, dv, dsink, softmax_d }; + if (learnable_sink_.has_value()) { + if (!dsink_.has_value()) { + dsink = dsink.to(learnable_sink_.value().dtype()); + } else if (dsink_.value().dtype() != torch::kFloat32) { + dsink = dsink.to(dsink_.value().dtype()); + dsink_.value().copy_(dsink); + } + } + + return { dq, dk, dv, softmax_d, dsink }; } std::vector @@ -1270,7 +1309,6 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he std::optional &leftpad_k_, // batch_size std::optional &block_table_, // batch_size x max_num_blocks_per_seq std::optional &alibi_slopes_, // num_heads or batch_size x num_heads - std::optional &learnable_sink_, // num_heads std::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, @@ -1278,7 +1316,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he int window_size_right, const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - int num_splits + int num_splits, + std::optional learnable_sink_ ) { // Otherwise the kernel will be launched from cuda:0 device @@ -1509,7 +1548,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - set_params_sink(params, learnable_sink_, num_heads); + + at::Tensor learnable_sink; + if (learnable_sink_.has_value()) { + learnable_sink = learnable_sink_.value().to(torch::kFloat32); + TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together"); + CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink); + TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension"); + CHECK_SHAPE(learnable_sink, num_heads); + params.learnable_sink_ptr = learnable_sink.data_ptr(); + } else { + params.learnable_sink_ptr = nullptr; + } auto stream = at::cuda::getCurrentCUDAStream().stream(); // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 151c4428da2..4453c9e55f0 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -140,8 +140,7 @@ struct Flash_fwd_params : public Qkv_params { bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). - - void *__restrict__ learnable_sink_ptr; // For gpt_oss + void *__restrict__ learnable_sink_ptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -184,8 +183,7 @@ struct Flash_bwd_params : public Flash_fwd_params { bool deterministic; index_t dq_accum_split_stride; - - void *__restrict__ dsink_ptr; // For gpt_oss + void *__restrict__ dsink_ptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 8eee366c5b2..e48bc2ca11d 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -77,7 +77,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -86,9 +86,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Shared memory. extern __shared__ char smem_[]; - float dsink_val = 0.0f; - - const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; // The thread index. const int tidx = threadIdx.x; @@ -103,6 +100,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const BlockInfo binfo(params, bidb); if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + float dsink_val = 0.0f; + + const float sink_val = !Has_sink ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; + int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); if (Is_local) { m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); @@ -811,7 +812,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } __syncthreads(); - if (tidx == 0 && params.dsink_ptr != nullptr) { + if (tidx == 0 && params.dsink_ptr != nullptr && *dsink_block_sum_ptr != 0) { float* dsink_ptr = reinterpret_cast(params.dsink_ptr); float val = -(*dsink_block_sum_ptr) * params.rp_dropout; atomicAdd(dsink_ptr + bidh, val); @@ -822,7 +823,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv(const Params ¶ms) { // The block index for the batch. @@ -836,20 +837,20 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. @@ -859,7 +860,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 54581960caa..925942ac513 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -31,18 +31,18 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Has_sink, bool Is_even_M, bool Is_even_K) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_dq_dk_dv(params); + FLASH_NAMESPACE::compute_dq_dk_dv(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Has_sink, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -103,7 +103,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 9e96e9f487f..6f1fb01e475 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -48,7 +48,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -80,6 +80,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + const float sink_val = !Has_sink ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal || Is_local) { @@ -89,7 +91,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } } - const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { @@ -496,7 +497,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -576,6 +577,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons return; } + const float sink_val = !Has_sink ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; // We iterate over the blocks in reverse order. This is because the last block is the only one // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). @@ -834,7 +836,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons clear(acc_o); - const float sink_val = !Has_sink || params.learnable_sink_ptr == nullptr ? -INFINITY : reinterpret_cast(params.learnable_sink_ptr)[bidh]; FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax(sink_val); const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; @@ -1074,7 +1075,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1090,12 +1091,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1104,7 +1105,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1194,23 +1195,22 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); - + if constexpr(Has_sink) { const int row = tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; if (row < params.num_splits && col < kBlockM) { const index_t lse_offset = row_offset_lse + col; + int head_idx; if (params.unpadded_lse) { // LSE is written as (h, seqlen_q, b) or (h, b, seqlen_q). - const int head_idx = lse_offset / (params.b * params.seqlen_q); - const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); - lse_sum += sink_val_exp; + head_idx = lse_offset / (params.b * params.seqlen_q); } else { // LSE is written as (b, h, seqlen_q). - const int head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; - const float sink_val_exp = params.learnable_sink_ptr == nullptr ? 0.f : __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); - lse_sum += sink_val_exp; + head_idx = (lse_offset % (params.h * params.seqlen_q)) / params.seqlen_q; } + const float sink_val_exp = __expf(reinterpret_cast(params.learnable_sink_ptr)[head_idx] - lse_max); + lse_sum += sink_val_exp; } } diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index bce29444449..b83bc9fce3f 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -29,18 +29,18 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Has_sink, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Has_sink, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_sink) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); + FLASH_NAMESPACE::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -77,7 +77,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -121,7 +121,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Is_local, set Is_causal to false // We fix the lse for sink in combine_attn_seqk_parallel if Split is true. - auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index a419a8f72f0..2fd406bba86 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -135,7 +135,7 @@ struct Softmax { __forceinline__ __device__ Softmax(const float sink_val = -INFINITY) : sink_val(sink_val) {}; template - __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2, float softmax_scale=1.0) { + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2, float softmax_scale=1.0f) { // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); static_assert(decltype(size<0>(scores))::value == kNRows); diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1a8f6fbfa5c..fc8d33d6025 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -78,7 +78,6 @@ def _flash_attn_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -86,14 +85,14 @@ def _flash_attn_forward( window_size_right: int, softcap: float, alibi_slopes: Optional[torch.Tensor], - return_softmax: bool + return_softmax: bool, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( q, k, v, - learnable_sink, None, alibi_slopes, dropout_p, @@ -104,6 +103,7 @@ def _flash_attn_forward( softcap, return_softmax, None, + learnable_sink, ) return out, softmax_lse, S_dmask, rng_state @@ -113,7 +113,6 @@ def _flash_attn_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -121,7 +120,8 @@ def _flash_attn_forward_fake( window_size_right: int, softcap: float, alibi_slopes: Optional[torch.Tensor], - return_softmax: bool + return_softmax: bool, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] batch_size, seqlen_q, num_heads, head_size = q.shape @@ -139,7 +139,7 @@ def _flash_attn_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward else: _wrapped_flash_attn_forward = _flash_attn_forward @@ -150,7 +150,6 @@ def _flash_attn_varlen_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -167,13 +166,13 @@ def _flash_attn_varlen_forward( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( q, k, v, - learnable_sink, None, cu_seqlens_q, cu_seqlens_k, @@ -192,6 +191,7 @@ def _flash_attn_varlen_forward( softcap, return_softmax, None, + learnable_sink, ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -203,7 +203,6 @@ def _flash_attn_varlen_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -220,6 +219,7 @@ def _flash_attn_varlen_forward_fake( leftpad_k: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] paged_kv = block_table is not None @@ -238,25 +238,23 @@ def _flash_attn_varlen_forward_fake( return out, softmax_lse, p, rng_state -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward else: _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward -@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv", "dsink"), device_types="cuda") +@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], - dsink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -266,6 +264,8 @@ def _flash_attn_backward( alibi_slopes: Optional[torch.Tensor], deterministic: bool, rng_state: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] @@ -273,20 +273,18 @@ def _flash_attn_backward( dq, dk, dv, - dsink, softmax_d, + dsink, ) = flash_attn_gpu.bwd( dout, q, k, v, - learnable_sink, out, softmax_lse, dq, dk, dv, - dsink, alibi_slopes, dropout_p, softmax_scale, @@ -297,6 +295,8 @@ def _flash_attn_backward( deterministic, None, rng_state, + learnable_sink, + dsink, ) return softmax_d @@ -307,13 +307,11 @@ def _flash_attn_backward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], - dsink: Optional[torch.Tensor], dropout_p: float, softmax_scale: float, causal: bool, @@ -323,6 +321,8 @@ def _flash_attn_backward_fake( alibi_slopes: Optional[torch.Tensor], deterministic: bool, rng_state: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] if dq is None: @@ -342,25 +342,23 @@ def _flash_attn_backward_fake( return softmax_d -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward else: _wrapped_flash_attn_backward = _flash_attn_backward -@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv", "dsink"), device_types="cuda") +@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_varlen_backward( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], - dsink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -375,6 +373,8 @@ def _flash_attn_varlen_backward( deterministic: bool, rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] @@ -382,20 +382,18 @@ def _flash_attn_varlen_backward( dq, dk, dv, - dsink, softmax_d, + dsink, ) = flash_attn_gpu.varlen_bwd( dout, q, k, v, - learnable_sink, out, softmax_lse, dq, dk, dv, - dsink, cu_seqlens_q, cu_seqlens_k, alibi_slopes, @@ -411,6 +409,8 @@ def _flash_attn_varlen_backward( deterministic, None, rng_state, + learnable_sink, + dsink, ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() @@ -423,13 +423,11 @@ def _flash_attn_varlen_backward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - learnable_sink: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], - dsink: Optional[torch.Tensor], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, @@ -444,6 +442,8 @@ def _flash_attn_varlen_backward_fake( deterministic: bool, rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, + learnable_sink: Optional[torch.Tensor] = None, + dsink: Optional[torch.Tensor] = None, ) -> torch.Tensor: dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] batch_size = cu_seqlens_q.numel() - 1 @@ -465,7 +465,7 @@ def _flash_attn_varlen_backward_fake( return softmax_d -if torch.__version__ >= "2.4.0": +if False: # if torch.__version__ >= "2.4.0": _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward else: _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward @@ -476,7 +476,6 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): def forward( ctx, qkv, - learnable_sink, dropout_p, softmax_scale, causal, @@ -486,6 +485,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and qkv.requires_grad if softmax_scale is None: @@ -500,7 +500,6 @@ def forward( q, k, v, - learnable_sink, dropout_p, softmax_scale, causal=causal, @@ -509,9 +508,10 @@ def forward( softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -524,10 +524,10 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, learnable_sink = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -537,13 +537,11 @@ def backward(ctx, dout, *args): q, k, v, - learnable_sink, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], - dsink, ctx.dropout_p, ctx.softmax_scale, ctx.causal, @@ -553,9 +551,11 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, dsink, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, dsink class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): @@ -563,7 +563,6 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): def forward( ctx, qkv, - learnable_sink, cu_seqlens, max_seqlen, dropout_p, @@ -575,6 +574,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and qkv.requires_grad if softmax_scale is None: @@ -589,7 +589,6 @@ def forward( q, k, v, - learnable_sink, cu_seqlens, cu_seqlens, max_seqlen, @@ -603,9 +602,10 @@ def forward( alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=None, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale @@ -619,10 +619,10 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, learnable_sink, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens, rng_state, learnable_sink = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -632,13 +632,11 @@ def backward(ctx, dout, *args): q, k, v, - learnable_sink, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], - dsink, cu_seqlens, cu_seqlens, ctx.max_seqlen, @@ -652,9 +650,11 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, dsink, None, None, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, None, None, dsink class FlashAttnKVPackedFunc(torch.autograd.Function): @@ -663,7 +663,6 @@ def forward( ctx, q, kv, - learnable_sink, dropout_p, softmax_scale, causal, @@ -673,6 +672,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, kv] @@ -689,7 +689,6 @@ def forward( q, k, v, - learnable_sink, dropout_p, softmax_scale, causal=causal, @@ -698,9 +697,10 @@ def forward( softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -713,11 +713,11 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, learnable_sink = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -727,13 +727,11 @@ def backward(ctx, dout, *args): q, k, v, - learnable_sink, out, softmax_lse, dq, dkv[:, :, 0], dkv[:, :, 1], - dsink, ctx.dropout_p, ctx.softmax_scale, ctx.causal, @@ -743,10 +741,12 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, dsink, None, None, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None, None, dsink class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): @@ -755,7 +755,6 @@ def forward( ctx, q, kv, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -769,6 +768,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, kv] @@ -785,7 +785,6 @@ def forward( q, k, v, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -799,10 +798,11 @@ def forward( alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=None, + learnable_sink=learnable_sink, ) if is_grad: ctx.save_for_backward( - q, k, v, learnable_sink, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q @@ -818,11 +818,11 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, learnable_sink, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink = ctx.saved_tensors dq = torch.empty_like(q) kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -832,13 +832,11 @@ def backward(ctx, dout, *args): q, k, v, - learnable_sink, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1], - dsink, cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, @@ -852,10 +850,12 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, dsink, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, dsink class FlashAttnFunc(torch.autograd.Function): @@ -865,7 +865,6 @@ def forward( q, k, v, - learnable_sink, dropout_p, softmax_scale, causal, @@ -875,6 +874,7 @@ def forward( deterministic, return_softmax, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] @@ -890,7 +890,6 @@ def forward( q, k, v, - learnable_sink, dropout_p, softmax_scale, causal=causal, @@ -899,9 +898,10 @@ def forward( softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, + learnable_sink=learnable_sink, ) if is_grad: - ctx.save_for_backward(q, k, v, learnable_sink, out_padded, softmax_lse, rng_state) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state, learnable_sink) ctx.dropout_p = dropout_p ctx.softmax_scale = softmax_scale ctx.causal = causal @@ -914,9 +914,9 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, learnable_sink, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, rng_state, learnable_sink = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(3) dout_padded = dout if head_size_og % 8 != 0: @@ -926,13 +926,11 @@ def backward(ctx, dout, *args): q, k, v, - learnable_sink, out, softmax_lse, dq, dk, dv, - dsink, ctx.dropout_p, ctx.softmax_scale, ctx.causal, @@ -942,11 +940,13 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, dsink, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, dsink class FlashAttnVarlenFunc(torch.autograd.Function): @@ -956,7 +956,6 @@ def forward( q, k, v, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -971,6 +970,7 @@ def forward( return_softmax, block_table, is_grad_enabled, + learnable_sink, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] @@ -986,7 +986,6 @@ def forward( q, k, v, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1000,10 +999,11 @@ def forward( alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=block_table, + learnable_sink=learnable_sink, ) if is_grad: ctx.save_for_backward( - q, k, v, learnable_sink, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q @@ -1020,9 +1020,9 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, learnable_sink, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, learnable_sink = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - dsink = None if learnable_sink is None else torch.zeros_like(learnable_sink) + dsink = None if learnable_sink is None else torch.empty_like(learnable_sink) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: @@ -1032,13 +1032,11 @@ def backward(ctx, dout, *args): q, k, v, - learnable_sink, out, softmax_lse, dq, dk, dv, - dsink, cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, @@ -1052,11 +1050,13 @@ def backward(ctx, dout, *args): ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, + learnable_sink=learnable_sink, + dsink=dsink, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, dsink, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, dsink def flash_attn_qkvpacked_func( @@ -1067,9 +1067,9 @@ def flash_attn_qkvpacked_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # <=0.0 means deactivate alibi_slopes=None, - learnable_sink=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -1091,13 +1091,13 @@ def flash_attn_qkvpacked_func( softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The @@ -1109,7 +1109,6 @@ def flash_attn_qkvpacked_func( """ return FlashAttnQKVPackedFunc.apply( qkv, - learnable_sink, dropout_p, softmax_scale, causal, @@ -1119,6 +1118,7 @@ def flash_attn_qkvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1131,9 +1131,9 @@ def flash_attn_kvpacked_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, - learnable_sink=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than @@ -1172,13 +1172,13 @@ def flash_attn_kvpacked_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The @@ -1191,7 +1191,6 @@ def flash_attn_kvpacked_func( return FlashAttnKVPackedFunc.apply( q, kv, - learnable_sink, dropout_p, softmax_scale, causal, @@ -1201,6 +1200,7 @@ def flash_attn_kvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1214,9 +1214,9 @@ def flash_attn_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, - learnable_sink=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -1252,13 +1252,13 @@ def flash_attn_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The @@ -1272,7 +1272,6 @@ def flash_attn_func( q, k, v, - learnable_sink, dropout_p, softmax_scale, causal, @@ -1282,6 +1281,7 @@ def flash_attn_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1295,9 +1295,9 @@ def flash_attn_varlen_qkvpacked_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, - learnable_sink=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -1322,13 +1322,13 @@ def flash_attn_varlen_qkvpacked_func( softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The @@ -1340,7 +1340,6 @@ def flash_attn_varlen_qkvpacked_func( """ return FlashAttnVarlenQKVPackedFunc.apply( qkv, - learnable_sink, cu_seqlens, max_seqlen, dropout_p, @@ -1352,6 +1351,7 @@ def flash_attn_varlen_qkvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1368,9 +1368,9 @@ def flash_attn_varlen_kvpacked_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, - learnable_sink=None, deterministic=False, return_attn_probs=False, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than @@ -1415,13 +1415,13 @@ def flash_attn_varlen_kvpacked_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The @@ -1434,7 +1434,6 @@ def flash_attn_varlen_kvpacked_func( return FlashAttnVarlenKVPackedFunc.apply( q, kv, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1448,6 +1447,7 @@ def flash_attn_varlen_kvpacked_func( deterministic, return_attn_probs, torch.is_grad_enabled(), + learnable_sink, ) @@ -1465,10 +1465,10 @@ def flash_attn_varlen_func( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, - learnable_sink=None, deterministic=False, return_attn_probs=False, block_table=None, + learnable_sink=None, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -1511,13 +1511,13 @@ def flash_attn_varlen_func( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). + learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an + additional attention logit to each query's attention scores before softmax. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The @@ -1531,7 +1531,6 @@ def flash_attn_varlen_func( q, k, v, - learnable_sink, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1546,6 +1545,7 @@ def flash_attn_varlen_func( return_attn_probs, block_table, torch.is_grad_enabled(), + learnable_sink, ) @@ -1567,9 +1567,9 @@ def flash_attn_with_kvcache( softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, alibi_slopes=None, - learnable_sink=None, num_splits=0, return_softmax_lse=False, + learnable_sink=None, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -1646,8 +1646,6 @@ def flash_attn_with_kvcache( alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. - learnable_sink: (nheads,) fp32 tensor. A learnable attention sink value per head that is appended as an - additional attention logit to each query's attention scores before softmax. num_splits: int. If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. @@ -1685,7 +1683,6 @@ def flash_attn_with_kvcache( cache_leftpad, block_table, alibi_slopes, - learnable_sink, None, softmax_scale, causal, @@ -1694,5 +1691,6 @@ def flash_attn_with_kvcache( softcap, rotary_interleaved, num_splits, + learnable_sink, ) return (out, softmax_lse) if return_softmax_lse else out diff --git a/setup.py b/setup.py index fafea904998..2eb88cca0ed 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" SKIP_CK_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CK_BUILD", "TRUE") == "TRUE" if USE_TRITON_ROCM else False NVCC_THREADS = os.getenv("NVCC_THREADS") or "4" +DISABLE_SINK = os.getenv("FLASH_ATTENTION_DISABLE_SINK", "FALSE") == "TRUE" @functools.lru_cache(maxsize=None) def cuda_archs() -> str: @@ -271,6 +272,7 @@ def validate_and_update_archs(archs): # "-DFLASHATTENTION_DISABLE_SOFTCAP", # "-DFLASHATTENTION_DISABLE_UNEVEN_K", # "-DFLASHATTENTION_DISABLE_LOCAL", + # "-DFLASHATTENTION_DISABLE_SINK", ] compiler_c17_flag=["-O3", "-std=c++17"] diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 3236cb69443..101937230ed 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -226,11 +226,11 @@ def attention_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size - learnable_sink=None, softcap=0.0, upcast=True, reorder_ops=False, key_leftpad=None, + learnable_sink=None, ): """ Arguments: @@ -324,11 +324,11 @@ def attention_kvpacked_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size - learnable_sink=None, softcap=0.0, upcast=True, reorder_ops=False, key_leftpad=None, + learnable_sink=None, ): return attention_ref( q, @@ -342,10 +342,10 @@ def attention_kvpacked_ref( upcast=upcast, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, reorder_ops=reorder_ops, key_leftpad=key_leftpad, + learnable_sink=learnable_sink, ) @@ -357,10 +357,10 @@ def attention_qkvpacked_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size - learnable_sink=None, softcap=0.0, upcast=True, reorder_ops=False, + learnable_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -374,9 +374,9 @@ def attention_qkvpacked_ref( upcast=upcast, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, reorder_ops=reorder_ops, + learnable_sink=learnable_sink, ) @@ -585,32 +585,6 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -def get_attention_mask( - seqlen_q, - seqlen_k, - causal, - device, - window_size=(-1, -1), - query_padding_mask=None, - key_padding_mask=None, - key_leftpad=None, -): - if causal: - window_size = (window_size[0], 0) - if window_size[0] >= 0 or window_size[1] >= 0: - local_mask = construct_local_mask( - seqlen_q, - seqlen_k, - window_size, - query_padding_mask, - key_padding_mask, - device, - key_leftpad=key_leftpad, - ) - return torch.where(local_mask, torch.tensor(float('-inf')), torch.tensor(0.0)) - return None - - @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False, True]) @@ -650,9 +624,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal) else: alibi_slopes, attn_bias = None, None - if has_learnable_sink: - learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None out, lse, S_dmask = flash_attn_qkvpacked_func( qkv, @@ -660,9 +633,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, - learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( @@ -709,9 +682,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) # v = qkv[:, :, 2].float() # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float() @@ -814,10 +787,6 @@ def test_flash_attn_varlen_qkvpacked( qkv = torch.randn( batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True ) - if has_learnable_sink: - learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: - learnable_sink = None key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') @@ -829,6 +798,10 @@ def test_flash_attn_varlen_qkvpacked( else: alibi_slopes, attn_bias = None, None + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: + learnable_sink = None + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True ) @@ -841,9 +814,9 @@ def test_flash_attn_varlen_qkvpacked( causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, - learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) out = output_pad_fn(out_unpad) if dropout_p > 0.0: @@ -898,9 +871,9 @@ def test_flash_attn_varlen_qkvpacked( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -1029,12 +1002,11 @@ def test_flash_attn_output( attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) else: alibi_slopes, attn_bias = None, None - if has_learnable_sink: - learnable_sink = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) * 0.3 - if softcap > 0: - learnable_sink = learnable_sink * softcap * 2 - else: + + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None + if kvpacked: out, lse, S_dmask = flash_attn_kvpacked_func( q, @@ -1044,9 +1016,9 @@ def test_flash_attn_output( window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, - learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) else: out, lse, S_dmask = flash_attn_func( @@ -1058,9 +1030,9 @@ def test_flash_attn_output( window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, - learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( @@ -1113,8 +1085,8 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -1126,10 +1098,10 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) else: out_ref, attn_ref = attention_ref( @@ -1143,8 +1115,8 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_ref( q, @@ -1157,10 +1129,10 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1219,9 +1191,9 @@ def test_flash_attn_output( print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") if has_learnable_sink: - dsink, = torch.autograd.grad(out, (learnable_sink,), g) - dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) - dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") @@ -1330,10 +1302,11 @@ def test_flash_attn_varlen_output( ) else: alibi_slopes, attn_bias = None, None - if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: + + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None + if kvpacked: ( q_unpad, @@ -1358,11 +1331,11 @@ def test_flash_attn_varlen_output( dropout_p, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) else: ( @@ -1391,11 +1364,11 @@ def test_flash_attn_varlen_output( dropout_p, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) out = output_pad_fn(out_unpad) if dropout_p > 0.0: @@ -1453,8 +1426,8 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -1466,10 +1439,10 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) else: out_ref, attn_ref = attention_ref( @@ -1483,8 +1456,8 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, + learnable_sink=learnable_sink, ) out_pt, attn_pt = attention_ref( q, @@ -1497,10 +1470,10 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1561,9 +1534,9 @@ def test_flash_attn_varlen_output( print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") if has_learnable_sink: - dsink, = torch.autograd.grad(out, (learnable_sink,), g) - dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) - dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") @@ -1633,9 +1606,8 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_ q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, learnable_sink=learnable_sink) out_ref, attn_ref = attention_ref( @@ -1652,9 +1624,9 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_ None, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1692,9 +1664,9 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_ print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") if has_learnable_sink: - dsink, = torch.autograd.grad(out, (learnable_sink,), g) - dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) - dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") @@ -1743,7 +1715,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, has_ # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) @pytest.mark.parametrize("has_learnable_sink", [False, True]) def test_flash_attn_varlen_causal( - seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, has_learnable_sink, dtype + seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype, has_learnable_sink ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -1773,9 +1745,8 @@ def test_flash_attn_varlen_causal( k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype ) - if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") @@ -1805,8 +1776,8 @@ def test_flash_attn_varlen_causal( 0.0, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, block_table=block_table, + learnable_sink=learnable_sink, ) out = output_pad_fn(out_unpad) out_ref, attn_ref = attention_ref( @@ -1833,9 +1804,9 @@ def test_flash_attn_varlen_causal( None, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -1878,9 +1849,9 @@ def test_flash_attn_varlen_causal( print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") if has_learnable_sink: - dsink, = torch.autograd.grad(out, (learnable_sink,), g) - dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) - dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") @@ -1953,10 +1924,11 @@ def test_flash_attn_splitkv( attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) else: alibi_slopes, attn_bias = None, None - if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: + + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None + out, lse, _ = flash_attn_func( q, k, @@ -1965,9 +1937,9 @@ def test_flash_attn_splitkv( causal=causal, window_size=window_size, alibi_slopes=alibi_slopes, - learnable_sink=learnable_sink, deterministic=deterministic, return_attn_probs=True, + learnable_sink=learnable_sink, ) out_ref, attn_ref = attention_ref( q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size, learnable_sink=learnable_sink @@ -1983,9 +1955,9 @@ def test_flash_attn_splitkv( None, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, upcast=False, reorder_ops=True, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") @@ -2023,9 +1995,9 @@ def test_flash_attn_splitkv( print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") if has_learnable_sink: - dsink, = torch.autograd.grad(out, (learnable_sink,), g) - dsink_ref, = torch.autograd.grad(out_ref, (learnable_sink,), g) - dsink_pt, = torch.autograd.grad(out_pt, (learnable_sink,), g) + (dsink,) = torch.autograd.grad(out, learnable_sink, g) + (dsink_ref,) = torch.autograd.grad(out_ref, learnable_sink, g) + (dsink_pt,) = torch.autograd.grad(out_pt, learnable_sink, g) print(f"dSink max diff: {(dsink - dsink_ref).abs().max().item()}") print(f"dSink mean diff: {(dsink - dsink_ref).abs().mean().item()}") print(f"dSink Pytorch max diff: {(dsink_pt - dsink_ref).abs().max().item()}") @@ -2110,7 +2082,7 @@ def test_flash_attn_kvcache( mha_type, num_splits, dtype, - has_learnable_sink + has_learnable_sink, ): if seqlen_q > seqlen_k and new_kv: pytest.skip() @@ -2194,9 +2166,9 @@ def test_flash_attn_kvcache( ) else: alibi_slopes, attn_bias = None, None - if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: + + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) if rotary_dim > 0: @@ -2265,8 +2237,8 @@ def test_flash_attn_kvcache( window_size=window_size, rotary_interleaved=rotary_interleaved, alibi_slopes=alibi_slopes, - learnable_sink=learnable_sink, num_splits=num_splits, + learnable_sink=learnable_sink, ) # out = flash_attn_with_kvcache( # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size @@ -2289,8 +2261,8 @@ def test_flash_attn_kvcache( None, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, key_leftpad=cache_leftpad, + learnable_sink=learnable_sink, ) out_pt, _ = attention_ref( q_ro, @@ -2303,10 +2275,10 @@ def test_flash_attn_kvcache( None, causal=causal, window_size=window_size, - learnable_sink=learnable_sink, upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, + learnable_sink=learnable_sink, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -2406,9 +2378,8 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) - if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) * 0.3 - else: + learnable_sink = torch.randn(nheads, device=device, dtype=dtype, requires_grad=True) * 0.3 + if not has_learnable_sink: learnable_sink = None torch.random.manual_seed(42) out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True, learnable_sink=learnable_sink) @@ -2477,7 +2448,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype, has_learnable_sink): k.requires_grad_(True) v.requires_grad_(True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=dtype, requires_grad=True) else: learnable_sink = None out = flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink) @@ -2542,7 +2513,7 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype, has_learnable_sink): for _ in range(3) ] if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=dtype, requires_grad=True) else: learnable_sink = None out = rearrange(flash_attn_func(q, k, v, causal=causal, learnable_sink=learnable_sink), "b s ... -> s b ...") @@ -2612,7 +2583,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype, has_learnable_sink): k.requires_grad_(True) v.requires_grad_(True) if has_learnable_sink: - learnable_sink = torch.randn((nheads,), device=device, dtype=torch.float32, requires_grad=True) + learnable_sink = torch.randn((nheads,), device=device, dtype=dtype, requires_grad=True) else: learnable_sink = None