Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d715844
Add new API.
Aug 15, 2025
5501390
Add tests for sink.
Aug 15, 2025
f3b159e
Right fwd.
Aug 16, 2025
be46c29
Right bwd.
Aug 17, 2025
a53a8d8
Refine.
Aug 17, 2025
5edbfdc
Fix warp reduce.
Aug 18, 2025
b4fc07b
Add Tests.
Aug 19, 2025
66b88a5
Fix max in fwd.
Aug 20, 2025
27a005f
Fix init dsink.
Aug 20, 2025
26b4704
Add to existing interface.
Aug 21, 2025
9a57df2
Modify bwd.
Aug 21, 2025
4cf0f88
Fix.
Aug 21, 2025
9880064
Fix lse in combine_attn_seqk_parallel.
Aug 21, 2025
9117304
Update tests.
Aug 21, 2025
46b8ee9
Clean code.
Aug 22, 2025
c00f806
forbidding Learnable sink and ALiBi to party together.
Aug 22, 2025
f5044c8
learnable_sink optonal.
Aug 23, 2025
793f3f5
Fix arg learnable_sink.
Aug 24, 2025
51524da
Fix bwd with sink and dropout.
Aug 26, 2025
a317bae
Debug for softmax LSE calculation.
Henry0215 Sep 4, 2025
594365e
Debug for sink calculation.
Henry0215 Sep 5, 2025
166b7c5
Merge pull request #2 from Henry0215/feature/attention_with_sink
aoxy Sep 10, 2025
0a4ebac
Fix some bugs.
Sep 10, 2025
067cf47
Update tests.
Sep 11, 2025
408083e
Fix.
Sep 13, 2025
f5dd38e
Fix bugs.
Sep 15, 2025
b0ed042
clean code.
Sep 16, 2025
b8f9953
Fix tests.
Sep 18, 2025
8a80215
Merge branch 'main' into feature/attention_with_sink
Sep 18, 2025
9d099a0
Add new line.
Oct 21, 2025
9eb63cb
Merge branch 'main' into feature/attention_with_sink
Oct 21, 2025
be83a26
Merge branch 'main' into feature/attention_with_sink
aoxy Jan 22, 2026
0d26f1e
Refactoring: Unifying sink-related parameter processing and optimizin…
aoxy Feb 25, 2026
bb35f08
Merge branch 'main' into feature/attention_with_sink
aoxy Feb 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion benchmarks/benchmark_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def time_fwd_bwd(func, *args, **kwargs):
dim = 2048
dropout_p = 0.0

methods = (["Flash2", "Pytorch"]
methods = (["Flash2", "Pytorch", "Flash2Sink"]
+ (["Triton"] if attention_triton is not None else [])
+ (["xformers.c"] if xops is not None else [])
+ (["xformers.f"] if xops is not None else []))
Expand Down Expand Up @@ -106,6 +106,17 @@ def time_fwd_bwd(func, *args, **kwargs):
time_f[config, "Flash2"] = f
time_b[config, "Flash2"] = b

# FlashAttention 2 with sink
if "Flash2Sink" in methods:
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
device=device, dtype=dtype, requires_grad=True)
sink = torch.randn((nheads,), dtype=dtype, device=device, requires_grad=True)
f, b = time_fwd_bwd(
flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, learnable_sink=sink, repeats=repeats, verbose=False
)
time_f[config, "Flash2Sink"] = f
time_b[config, "Flash2Sink"] = b

# PyTorch baseline
if "Pytorch" in methods:
try:
Expand Down
132 changes: 121 additions & 11 deletions csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
std::optional<at::Generator> gen_,
std::optional<const at::Tensor> learnable_sink_) {

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
Expand Down Expand Up @@ -405,7 +406,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
Expand Down Expand Up @@ -469,6 +470,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
softcap
);


// Keep references to these tensors to extend their lifetime
at::Tensor softmax_lse_accum, out_accum;
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
Expand All @@ -494,6 +496,18 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

at::Tensor learnable_sink;
if (learnable_sink_.has_value()) {
learnable_sink = learnable_sink_.value().to(torch::kFloat32);
TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together");
CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink);
TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension");
CHECK_SHAPE(learnable_sink, num_heads);
params.learnable_sink_ptr = learnable_sink.data_ptr();
} else {
params.learnable_sink_ptr = nullptr;
}

if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
Expand All @@ -515,7 +529,7 @@ std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
Expand All @@ -532,7 +546,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
std::optional<at::Generator> gen_,
std::optional<const at::Tensor> learnable_sink_) {

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
Expand Down Expand Up @@ -589,7 +604,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value();
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
Expand Down Expand Up @@ -734,6 +749,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

at::Tensor learnable_sink;
if (learnable_sink_.has_value()) {
learnable_sink = learnable_sink_.value().to(torch::kFloat32);
TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together");
CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink);
TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension");
CHECK_SHAPE(learnable_sink, num_heads);
params.learnable_sink_ptr = learnable_sink.data_ptr();
} else {
params.learnable_sink_ptr = nullptr;
}

if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream, paged_KV);
Expand Down Expand Up @@ -783,7 +810,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
const float softcap,
const bool deterministic,
std::optional<at::Generator> gen_,
std::optional<at::Tensor> &rng_state) {
std::optional<at::Tensor> &rng_state,
std::optional<const at::Tensor> learnable_sink_,
std::optional<at::Tensor> dsink_) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
Expand Down Expand Up @@ -952,6 +981,30 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

at::Tensor dsink;
at::Tensor learnable_sink;
if (learnable_sink_.has_value()) {
learnable_sink = learnable_sink_.value().to(torch::kFloat32);
TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together");
CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink);
TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension");
CHECK_SHAPE(learnable_sink, num_heads);
if (dsink_.has_value() && dsink_.value().dtype() == torch::kFloat32) {
dsink = dsink_.value();
CHECK_DEVICE(dsink); CHECK_CONTIGUOUS(dsink);
TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension");
CHECK_SHAPE(dsink, num_heads);
} else {
dsink = torch::empty_like(learnable_sink);
}
dsink.zero_();
params.learnable_sink_ptr = learnable_sink.data_ptr();
params.dsink_ptr = dsink.data_ptr();
} else {
params.learnable_sink_ptr = nullptr;
params.dsink_ptr = nullptr;
}

if (seqlen_q > 0) {
launch(params, stream);
} else {
Expand All @@ -967,7 +1020,16 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
}

return { dq, dk, dv, softmax_d };
if (learnable_sink_.has_value()) {
if (!dsink_.has_value()) {
dsink = dsink.to(learnable_sink_.value().dtype());
} else if (dsink_.value().dtype() != torch::kFloat32) {
dsink = dsink.to(dsink_.value().dtype());
dsink_.value().copy_(dsink);
}
}

return { dq, dk, dv, softmax_d, dsink };
}

std::vector<at::Tensor>
Expand All @@ -994,7 +1056,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const float softcap,
const bool deterministic,
std::optional<at::Generator> gen_,
std::optional<at::Tensor> &rng_state) {
std::optional<at::Tensor> &rng_state,
std::optional<const at::Tensor> learnable_sink_,
std::optional<at::Tensor> dsink_) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
Expand Down Expand Up @@ -1181,6 +1245,30 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

at::Tensor dsink;
at::Tensor learnable_sink;
if (learnable_sink_.has_value()) {
learnable_sink = learnable_sink_.value().to(torch::kFloat32);
TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together");
CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink);
TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension");
CHECK_SHAPE(learnable_sink, num_heads);
if (dsink_.has_value() && dsink_.value().dtype() == torch::kFloat32) {
dsink = dsink_.value();
CHECK_DEVICE(dsink); CHECK_CONTIGUOUS(dsink);
TORCH_CHECK(dsink.stride(-1) == 1, "dsink tensor must have contiguous last dimension");
CHECK_SHAPE(dsink, num_heads);
} else {
dsink = torch::empty_like(learnable_sink);
}
dsink.zero_();
params.learnable_sink_ptr = learnable_sink.data_ptr();
params.dsink_ptr = dsink.data_ptr();
} else {
params.learnable_sink_ptr = nullptr;
params.dsink_ptr = nullptr;
}

if (max_seqlen_q > 0) {
launch(params, stream);
} else {
Expand All @@ -1196,7 +1284,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
}

return { dq, dk, dv, softmax_d };
if (learnable_sink_.has_value()) {
if (!dsink_.has_value()) {
dsink = dsink.to(learnable_sink_.value().dtype());
} else if (dsink_.value().dtype() != torch::kFloat32) {
dsink = dsink.to(dsink_.value().dtype());
dsink_.value().copy_(dsink);
}
}

return { dq, dk, dv, softmax_d, dsink };
}

std::vector<at::Tensor>
Expand All @@ -1219,7 +1316,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
int num_splits,
std::optional<const at::Tensor> learnable_sink_
) {

// Otherwise the kernel will be launched from cuda:0 device
Expand Down Expand Up @@ -1275,7 +1373,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value() && !learnable_sink_.has_value();
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
Expand Down Expand Up @@ -1451,6 +1549,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

at::Tensor learnable_sink;
if (learnable_sink_.has_value()) {
learnable_sink = learnable_sink_.value().to(torch::kFloat32);
TORCH_CHECK(params.alibi_slopes_ptr == nullptr, "Learnable sink and ALiBi slopes cannot be used together");
CHECK_DEVICE(learnable_sink); CHECK_CONTIGUOUS(learnable_sink);
TORCH_CHECK(learnable_sink.stride(-1) == 1, "Learnable sink tensor must have contiguous last dimension");
CHECK_SHAPE(learnable_sink, num_heads);
params.learnable_sink_ptr = learnable_sink.data_ptr();
} else {
params.learnable_sink_ptr = nullptr;
}

auto stream = at::cuda::getCurrentCUDAStream().stream();
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
// or paged KV cache
Expand Down
2 changes: 2 additions & 0 deletions csrc/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ struct Flash_fwd_params : public Qkv_params {

bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
void *__restrict__ learnable_sink_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -182,6 +183,7 @@ struct Flash_bwd_params : public Flash_fwd_params {

bool deterministic;
index_t dq_accum_split_stride;
void *__restrict__ dsink_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading