feat: Implement Sink Attention #1

Draft

aoxy wants to merge 34 commits into main from feature/attention_with_sink

Conversation

aoxy commented Aug 18, 2025

This implementation provides a fused and highly optimized kernel for performing Sink Attention within the FlashAttention framework.

Key Changes

  • CUDA Kernels: The core forward and backward CUDA kernels (flash_fwd_kernel.h, flash_bwd_kernel.h) have been extended to optionally accept a sink tensor. The softmax calculation is modified to fold the sink logits into its normalization, so part of each query's attention probability mass can be absorbed by a virtual sink position instead of being forced onto real tokens.
  • C++ API: New backend functions (mha_sink_fwd, mha_sink_bwd) have been added to flash_api.cpp and exposed to Python, providing a direct interface to the new kernel logic.
  • Python Interface: A new user-facing function, flash_attn_sink_func, is introduced. It is implemented as a torch.autograd.Function that seamlessly integrates with PyTorch's automatic differentiation.
  • Testing & Benchmarking:
    • A naive, eager implementation of Sink Attention (naive_attn_with_sink.py) was added to serve as the ground truth for correctness; a sketch of this style of reference implementation is shown after this list.
    • Comprehensive tests (test.py, test_fused.py) were added to validate the fused kernel's forward and backward passes against the naive implementation, ensuring numerical correctness.
    • The main benchmark script (benchmark_flash_attention.py) has been updated to include Flash2Sink and Flash2SinkFused methods to measure the performance of this new feature.
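
To make the softmax change concrete, here is a minimal eager sketch of additive sink attention in the spirit of the naive reference. It is an illustration rather than the actual contents of naive_attn_with_sink.py: it assumes one learnable sink logit per head that joins the softmax normalization as a virtual key position but contributes nothing to the output, and it uses the usual (batch, seqlen, nheads, headdim) layout.

```python
import math

import torch


def naive_attn_with_sink(q, k, v, sinks, causal=False, softmax_scale=None):
    """Eager additive-sink attention (illustrative sketch).

    q, k, v: (batch, seqlen, nheads, headdim); sinks: (nheads,) per-head logits.
    The causal path assumes seqlen_q == seqlen_k.
    """
    batch, seqlen_q, nheads, headdim = q.shape
    seqlen_k = k.shape[1]
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(headdim)

    # (batch, nheads, seqlen_q, seqlen_k) attention scores, computed in fp32 for stability
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    if causal:
        mask = torch.triu(
            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))

    # Append the per-head sink logit as one extra virtual column, take the softmax
    # over real keys plus the sink, then drop the sink column: the sink only absorbs
    # probability mass and never contributes a value vector to the output.
    sink_col = sinks.float().view(1, nheads, 1, 1).expand(batch, nheads, seqlen_q, 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]

    out = torch.einsum("bhqk,bkhd->bqhd", probs, v.float())
    return out.to(q.dtype)
```

The fused flash_attn_sink_func path is expected to reproduce such a reference in both the forward output and the backward gradients (and, presumably, the gradient with respect to the sink tensor), which is what test.py and test_fused.py check against the naive implementation.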

Motivation

The growing need for efficient handling of long contexts and streaming applications in LLMs makes Sink Attention a vital feature. This implementation allows models that rely on this mechanism to leverage the speed and memory efficiency of FlashAttention.

Fixes Dao-AILab#1802

aoxy changed the title from "Feature/attention with sink" to "feat: Implement Sink Attention" on Aug 18, 2025

aoxy commented Aug 18, 2025

My current approach adds a new top-level interface, flash_attn_sink_func, which results in a lot of duplicated code. Would you consider adding the sink parameter to the existing interfaces instead?
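
For concreteness, the alternative being asked about would look roughly like the sketch below. This is a hypothetical, abbreviated signature (the real flash_attn_func takes additional arguments such as dropout_p and window_size that are omitted here); it only shows where an optional sink tensor could slot into the existing entry point.

```python
from typing import Optional

import torch


# Hypothetical, abbreviated sketch of the proposed alternative (not the real
# flash_attn signature): an optional per-head sink tensor on the existing entry
# point, where sinks=None falls back to standard FlashAttention behavior.
def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    sinks: Optional[torch.Tensor] = None,  # (nheads,) sink logits; new and optional
) -> torch.Tensor:
    ...
```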


aoxy commented Aug 18, 2025

H800

flash-attention/benchmarks$ python benchmark_flash_attention.py
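
For reading the numbers: the TFLOPs/s figures assume the usual FlashAttention benchmark FLOP model, sketched below rather than quoted from benchmark_flash_attention.py. Rows reporting 0.00 TFLOPs/s are most likely configurations where the eager PyTorch reference ran out of memory.

```python
# Assumed FLOP model behind the TFLOPs/s figures (a sketch of the common convention,
# not a verbatim excerpt from benchmark_flash_attention.py): the QK^T and PV matmuls
# cost 2 * seqlen^2 * headdim FLOPs each per head, causal masking halves the work,
# and the backward pass is counted as 2.5x the forward.
def attn_flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return {"fwd": f, "bwd": 2.5 * f, "fwd_bwd": 3.5 * f}[mode]

# TFLOPs/s = attn_flops(...) / measured_time_in_seconds / 1e12
```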

causal=False, headdim=64, batch_size=32, seqlen=512

Flash2 fwd: 254.08 TFLOPs/s, bwd: 185.81 TFLOPs/s, fwd + bwd: 201.26 TFLOPs/s
Flash2UnPacked fwd: 260.74 TFLOPs/s, bwd: 187.10 TFLOPs/s, fwd + bwd: 203.52 TFLOPs/s
Pytorch fwd: 51.80 TFLOPs/s, bwd: 68.36 TFLOPs/s, fwd + bwd: 62.64 TFLOPs/s
Flash2Sink fwd: 114.83 TFLOPs/s, bwd: 52.10 TFLOPs/s, fwd + bwd: 61.74 TFLOPs/s
Flash2SinkFused fwd: 233.36 TFLOPs/s, bwd: 188.36 TFLOPs/s, fwd + bwd: 199.34 TFLOPs/s

causal=False, headdim=64, batch_size=16, seqlen=1024

Flash2 fwd: 290.81 TFLOPs/s, bwd: 229.42 TFLOPs/s, fwd + bwd: 244.14 TFLOPs/s
Flash2UnPacked fwd: 291.11 TFLOPs/s, bwd: 229.90 TFLOPs/s, fwd + bwd: 244.59 TFLOPs/s
Pytorch fwd: 62.52 TFLOPs/s, bwd: 76.96 TFLOPs/s, fwd + bwd: 72.20 TFLOPs/s
Flash2Sink fwd: 169.66 TFLOPs/s, bwd: 73.21 TFLOPs/s, fwd + bwd: 87.41 TFLOPs/s
Flash2SinkFused fwd: 257.55 TFLOPs/s, bwd: 227.58 TFLOPs/s, fwd + bwd: 235.41 TFLOPs/s

causal=False, headdim=64, batch_size=8, seqlen=2048

Flash2 fwd: 306.05 TFLOPs/s, bwd: 252.78 TFLOPs/s, fwd + bwd: 266.01 TFLOPs/s
Flash2UnPacked fwd: 306.19 TFLOPs/s, bwd: 253.60 TFLOPs/s, fwd + bwd: 266.69 TFLOPs/s
Pytorch fwd: 60.19 TFLOPs/s, bwd: 88.61 TFLOPs/s, fwd + bwd: 78.08 TFLOPs/s
Flash2Sink fwd: 222.61 TFLOPs/s, bwd: 91.46 TFLOPs/s, fwd + bwd: 109.97 TFLOPs/s
Flash2SinkFused fwd: 270.78 TFLOPs/s, bwd: 249.79 TFLOPs/s, fwd + bwd: 255.44 TFLOPs/s

causal=False, headdim=64, batch_size=4, seqlen=4096

Flash2 fwd: 313.88 TFLOPs/s, bwd: 266.13 TFLOPs/s, fwd + bwd: 278.22 TFLOPs/s
Flash2UnPacked fwd: 312.33 TFLOPs/s, bwd: 267.13 TFLOPs/s, fwd + bwd: 278.65 TFLOPs/s
Pytorch fwd: 53.28 TFLOPs/s, bwd: 93.25 TFLOPs/s, fwd + bwd: 76.79 TFLOPs/s
Flash2Sink fwd: 263.30 TFLOPs/s, bwd: 104.61 TFLOPs/s, fwd + bwd: 126.38 TFLOPs/s
Flash2SinkFused fwd: 276.16 TFLOPs/s, bwd: 266.60 TFLOPs/s, fwd + bwd: 269.26 TFLOPs/s

causal=False, headdim=64, batch_size=2, seqlen=8192

Flash2 fwd: 313.67 TFLOPs/s, bwd: 275.84 TFLOPs/s, fwd + bwd: 285.68 TFLOPs/s
Flash2UnPacked fwd: 313.09 TFLOPs/s, bwd: 274.15 TFLOPs/s, fwd + bwd: 284.25 TFLOPs/s
Pytorch fwd: 46.64 TFLOPs/s, bwd: 95.56 TFLOPs/s, fwd + bwd: 73.52 TFLOPs/s
Flash2Sink fwd: 288.33 TFLOPs/s, bwd: 112.51 TFLOPs/s, fwd + bwd: 136.25 TFLOPs/s
Flash2SinkFused fwd: 276.66 TFLOPs/s, bwd: 276.00 TFLOPs/s, fwd + bwd: 276.19 TFLOPs/s

causal=False, headdim=64, batch_size=1, seqlen=16384

Flash2 fwd: 315.12 TFLOPs/s, bwd: 278.46 TFLOPs/s, fwd + bwd: 288.03 TFLOPs/s
Flash2UnPacked fwd: 311.14 TFLOPs/s, bwd: 279.91 TFLOPs/s, fwd + bwd: 288.17 TFLOPs/s
Pytorch fwd: 79.41 TFLOPs/s, bwd: 96.31 TFLOPs/s, fwd + bwd: 90.79 TFLOPs/s
Flash2Sink fwd: 295.66 TFLOPs/s, bwd: 116.95 TFLOPs/s, fwd + bwd: 141.36 TFLOPs/s
Flash2SinkFused fwd: 277.14 TFLOPs/s, bwd: 278.60 TFLOPs/s, fwd + bwd: 278.18 TFLOPs/s

causal=False, headdim=128, batch_size=32, seqlen=512

Flash2 fwd: 297.26 TFLOPs/s, bwd: 198.26 TFLOPs/s, fwd + bwd: 219.11 TFLOPs/s
Flash2UnPacked fwd: 299.51 TFLOPs/s, bwd: 199.16 TFLOPs/s, fwd + bwd: 220.24 TFLOPs/s
Pytorch fwd: 74.38 TFLOPs/s, bwd: 105.23 TFLOPs/s, fwd + bwd: 94.08 TFLOPs/s
Flash2Sink fwd: 124.07 TFLOPs/s, bwd: 55.59 TFLOPs/s, fwd + bwd: 65.99 TFLOPs/s
Flash2SinkFused fwd: 259.25 TFLOPs/s, bwd: 198.85 TFLOPs/s, fwd + bwd: 213.04 TFLOPs/s

causal=False, headdim=128, batch_size=16, seqlen=1024

Flash2 fwd: 337.05 TFLOPs/s, bwd: 238.64 TFLOPs/s, fwd + bwd: 260.36 TFLOPs/s
Flash2UnPacked fwd: 338.04 TFLOPs/s, bwd: 239.44 TFLOPs/s, fwd + bwd: 261.21 TFLOPs/s
Pytorch fwd: 100.02 TFLOPs/s, bwd: 130.78 TFLOPs/s, fwd + bwd: 120.22 TFLOPs/s
Flash2Sink fwd: 187.80 TFLOPs/s, bwd: 78.33 TFLOPs/s, fwd + bwd: 93.98 TFLOPs/s
Flash2SinkFused fwd: 296.11 TFLOPs/s, bwd: 239.82 TFLOPs/s, fwd + bwd: 253.59 TFLOPs/s

causal=False, headdim=128, batch_size=8, seqlen=2048

Flash2 fwd: 357.42 TFLOPs/s, bwd: 264.75 TFLOPs/s, fwd + bwd: 285.93 TFLOPs/s
Flash2UnPacked fwd: 356.93 TFLOPs/s, bwd: 260.65 TFLOPs/s, fwd + bwd: 282.41 TFLOPs/s
Pytorch fwd: 107.00 TFLOPs/s, bwd: 159.81 TFLOPs/s, fwd + bwd: 140.06 TFLOPs/s
Flash2Sink fwd: 251.48 TFLOPs/s, bwd: 97.88 TFLOPs/s, fwd + bwd: 118.57 TFLOPs/s
Flash2SinkFused fwd: 314.67 TFLOPs/s, bwd: 260.74 TFLOPs/s, fwd + bwd: 274.17 TFLOPs/s

causal=False, headdim=128, batch_size=4, seqlen=4096

Flash2 fwd: 366.47 TFLOPs/s, bwd: 279.28 TFLOPs/s, fwd + bwd: 299.65 TFLOPs/s
Flash2UnPacked fwd: 365.22 TFLOPs/s, bwd: 280.52 TFLOPs/s, fwd + bwd: 300.43 TFLOPs/s
Pytorch fwd: 100.46 TFLOPs/s, bwd: 175.38 TFLOPs/s, fwd + bwd: 144.57 TFLOPs/s
Flash2Sink fwd: 301.79 TFLOPs/s, bwd: 111.88 TFLOPs/s, fwd + bwd: 136.40 TFLOPs/s
Flash2SinkFused fwd: 323.99 TFLOPs/s, bwd: 278.59 TFLOPs/s, fwd + bwd: 290.21 TFLOPs/s

causal=False, headdim=128, batch_size=2, seqlen=8192

Flash2 fwd: 369.66 TFLOPs/s, bwd: 289.60 TFLOPs/s, fwd + bwd: 308.70 TFLOPs/s
Flash2UnPacked fwd: 368.44 TFLOPs/s, bwd: 290.14 TFLOPs/s, fwd + bwd: 308.89 TFLOPs/s
Pytorch fwd: 90.81 TFLOPs/s, bwd: 184.42 TFLOPs/s, fwd + bwd: 142.46 TFLOPs/s
Flash2Sink fwd: 335.86 TFLOPs/s, bwd: 120.21 TFLOPs/s, fwd + bwd: 147.22 TFLOPs/s
Flash2SinkFused fwd: 326.36 TFLOPs/s, bwd: 289.08 TFLOPs/s, fwd + bwd: 298.83 TFLOPs/s

causal=False, headdim=128, batch_size=1, seqlen=16384

Flash2 fwd: 370.63 TFLOPs/s, bwd: 295.78 TFLOPs/s, fwd + bwd: 313.90 TFLOPs/s
Flash2UnPacked fwd: 360.55 TFLOPs/s, bwd: 296.29 TFLOPs/s, fwd + bwd: 312.19 TFLOPs/s
Pytorch fwd: 153.21 TFLOPs/s, bwd: 188.86 TFLOPs/s, fwd + bwd: 177.09 TFLOPs/s
Flash2Sink fwd: 352.78 TFLOPs/s, bwd: 125.45 TFLOPs/s, fwd + bwd: 153.76 TFLOPs/s
Flash2SinkFused fwd: 329.46 TFLOPs/s, bwd: 296.61 TFLOPs/s, fwd + bwd: 305.31 TFLOPs/s

causal=True, headdim=64, batch_size=32, seqlen=512

Flash2 fwd: 175.09 TFLOPs/s, bwd: 124.75 TFLOPs/s, fwd + bwd: 135.91 TFLOPs/s
Flash2UnPacked fwd: 176.13 TFLOPs/s, bwd: 125.47 TFLOPs/s, fwd + bwd: 136.70 TFLOPs/s
Pytorch fwd: 16.10 TFLOPs/s, bwd: 34.13 TFLOPs/s, fwd + bwd: 25.85 TFLOPs/s
Flash2Sink fwd: 64.46 TFLOPs/s, bwd: 30.98 TFLOPs/s, fwd + bwd: 36.38 TFLOPs/s
Flash2SinkFused fwd: 165.45 TFLOPs/s, bwd: 124.65 TFLOPs/s, fwd + bwd: 134.09 TFLOPs/s

causal=True, headdim=64, batch_size=16, seqlen=1024

Flash2 fwd: 223.30 TFLOPs/s, bwd: 174.80 TFLOPs/s, fwd + bwd: 186.37 TFLOPs/s
Flash2UnPacked fwd: 223.73 TFLOPs/s, bwd: 175.38 TFLOPs/s, fwd + bwd: 186.92 TFLOPs/s
Pytorch fwd: 18.12 TFLOPs/s, bwd: 38.38 TFLOPs/s, fwd + bwd: 29.09 TFLOPs/s
Flash2Sink fwd: 106.79 TFLOPs/s, bwd: 49.59 TFLOPs/s, fwd + bwd: 58.55 TFLOPs/s
Flash2SinkFused fwd: 212.14 TFLOPs/s, bwd: 175.01 TFLOPs/s, fwd + bwd: 184.22 TFLOPs/s

causal=True, headdim=64, batch_size=8, seqlen=2048

Flash2 fwd: 255.18 TFLOPs/s, bwd: 217.07 TFLOPs/s, fwd + bwd: 226.75 TFLOPs/s
Flash2UnPacked fwd: 255.71 TFLOPs/s, bwd: 216.73 TFLOPs/s, fwd + bwd: 226.60 TFLOPs/s
Pytorch fwd: 17.56 TFLOPs/s, bwd: 44.27 TFLOPs/s, fwd + bwd: 30.86 TFLOPs/s
Flash2Sink fwd: 157.26 TFLOPs/s, bwd: 70.39 TFLOPs/s, fwd + bwd: 83.58 TFLOPs/s
Flash2SinkFused fwd: 245.43 TFLOPs/s, bwd: 217.60 TFLOPs/s, fwd + bwd: 224.89 TFLOPs/s

causal=True, headdim=64, batch_size=4, seqlen=4096

Flash2 fwd: 275.49 TFLOPs/s, bwd: 242.80 TFLOPs/s, fwd + bwd: 251.32 TFLOPs/s
Flash2UnPacked fwd: 274.88 TFLOPs/s, bwd: 243.27 TFLOPs/s, fwd + bwd: 251.53 TFLOPs/s
Pytorch fwd: 15.46 TFLOPs/s, bwd: 46.61 TFLOPs/s, fwd + bwd: 29.58 TFLOPs/s
Flash2Sink fwd: 206.07 TFLOPs/s, bwd: 88.81 TFLOPs/s, fwd + bwd: 106.05 TFLOPs/s
Flash2SinkFused fwd: 265.78 TFLOPs/s, bwd: 246.56 TFLOPs/s, fwd + bwd: 251.76 TFLOPs/s

causal=True, headdim=64, batch_size=2, seqlen=8192

Flash2 fwd: 286.01 TFLOPs/s, bwd: 263.67 TFLOPs/s, fwd + bwd: 269.69 TFLOPs/s
Flash2UnPacked fwd: 285.35 TFLOPs/s, bwd: 264.82 TFLOPs/s, fwd + bwd: 270.38 TFLOPs/s
Pytorch fwd: 14.19 TFLOPs/s, bwd: 47.71 TFLOPs/s, fwd + bwd: 28.49 TFLOPs/s
Flash2Sink fwd: 244.20 TFLOPs/s, bwd: 102.64 TFLOPs/s, fwd + bwd: 123.02 TFLOPs/s
Flash2SinkFused fwd: 277.75 TFLOPs/s, bwd: 262.48 TFLOPs/s, fwd + bwd: 266.67 TFLOPs/s

causal=True, headdim=64, batch_size=1, seqlen=16384

Flash2 fwd: 290.80 TFLOPs/s, bwd: 278.21 TFLOPs/s, fwd + bwd: 281.69 TFLOPs/s
Flash2UnPacked fwd: 285.52 TFLOPs/s, bwd: 273.40 TFLOPs/s, fwd + bwd: 276.76 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2Sink fwd: 269.39 TFLOPs/s, bwd: 112.03 TFLOPs/s, fwd + bwd: 134.47 TFLOPs/s
Flash2SinkFused fwd: 282.67 TFLOPs/s, bwd: 280.94 TFLOPs/s, fwd + bwd: 281.43 TFLOPs/s

causal=True, headdim=128, batch_size=32, seqlen=512

Flash2 fwd: 198.90 TFLOPs/s, bwd: 133.71 TFLOPs/s, fwd + bwd: 147.52 TFLOPs/s
Flash2UnPacked fwd: 200.02 TFLOPs/s, bwd: 134.39 TFLOPs/s, fwd + bwd: 148.29 TFLOPs/s
Pytorch fwd: 25.69 TFLOPs/s, bwd: 52.65 TFLOPs/s, fwd + bwd: 40.51 TFLOPs/s
Flash2Sink fwd: 69.12 TFLOPs/s, bwd: 33.30 TFLOPs/s, fwd + bwd: 39.08 TFLOPs/s
Flash2SinkFused fwd: 175.35 TFLOPs/s, bwd: 134.90 TFLOPs/s, fwd + bwd: 144.42 TFLOPs/s

causal=True, headdim=128, batch_size=16, seqlen=1024

Flash2 fwd: 254.38 TFLOPs/s, bwd: 184.45 TFLOPs/s, fwd + bwd: 200.18 TFLOPs/s
Flash2UnPacked fwd: 255.37 TFLOPs/s, bwd: 184.91 TFLOPs/s, fwd + bwd: 200.73 TFLOPs/s
Pytorch fwd: 31.54 TFLOPs/s, bwd: 65.44 TFLOPs/s, fwd + bwd: 50.07 TFLOPs/s
Flash2Sink fwd: 115.72 TFLOPs/s, bwd: 52.89 TFLOPs/s, fwd + bwd: 62.61 TFLOPs/s
Flash2SinkFused fwd: 231.08 TFLOPs/s, bwd: 185.47 TFLOPs/s, fwd + bwd: 196.55 TFLOPs/s

causal=True, headdim=128, batch_size=8, seqlen=2048

Flash2 fwd: 294.34 TFLOPs/s, bwd: 226.72 TFLOPs/s, fwd + bwd: 242.65 TFLOPs/s
Flash2UnPacked fwd: 294.98 TFLOPs/s, bwd: 227.10 TFLOPs/s, fwd + bwd: 243.08 TFLOPs/s
Pytorch fwd: 32.57 TFLOPs/s, bwd: 80.07 TFLOPs/s, fwd + bwd: 56.52 TFLOPs/s
Flash2Sink fwd: 173.63 TFLOPs/s, bwd: 74.36 TFLOPs/s, fwd + bwd: 88.88 TFLOPs/s
Flash2SinkFused fwd: 266.34 TFLOPs/s, bwd: 226.97 TFLOPs/s, fwd + bwd: 236.98 TFLOPs/s

causal=True, headdim=128, batch_size=4, seqlen=4096

Flash2 fwd: 317.22 TFLOPs/s, bwd: 255.42 TFLOPs/s, fwd + bwd: 270.48 TFLOPs/s
Flash2UnPacked fwd: 318.22 TFLOPs/s, bwd: 256.67 TFLOPs/s, fwd + bwd: 271.69 TFLOPs/s
Pytorch fwd: 29.69 TFLOPs/s, bwd: 87.89 TFLOPs/s, fwd + bwd: 56.33 TFLOPs/s
Flash2Sink fwd: 231.73 TFLOPs/s, bwd: 93.86 TFLOPs/s, fwd + bwd: 113.08 TFLOPs/s
Flash2SinkFused fwd: 288.53 TFLOPs/s, bwd: 257.22 TFLOPs/s, fwd + bwd: 265.45 TFLOPs/s

causal=True, headdim=128, batch_size=2, seqlen=8192

Flash2 fwd: 329.36 TFLOPs/s, bwd: 275.89 TFLOPs/s, fwd + bwd: 289.31 TFLOPs/s
Flash2UnPacked fwd: 331.18 TFLOPs/s, bwd: 273.30 TFLOPs/s, fwd + bwd: 287.66 TFLOPs/s
Pytorch fwd: 27.60 TFLOPs/s, bwd: 92.39 TFLOPs/s, fwd + bwd: 55.30 TFLOPs/s
Flash2Sink fwd: 277.66 TFLOPs/s, bwd: 108.54 TFLOPs/s, fwd + bwd: 131.41 TFLOPs/s
Flash2SinkFused fwd: 301.16 TFLOPs/s, bwd: 275.66 TFLOPs/s, fwd + bwd: 282.49 TFLOPs/s

causal=True, headdim=128, batch_size=1, seqlen=16384

Flash2 fwd: 332.55 TFLOPs/s, bwd: 296.42 TFLOPs/s, fwd + bwd: 305.92 TFLOPs/s
Flash2UnPacked fwd: 338.71 TFLOPs/s, bwd: 294.45 TFLOPs/s, fwd + bwd: 305.87 TFLOPs/s
Pytorch fwd: 36.72 TFLOPs/s, bwd: 94.61 TFLOPs/s, fwd + bwd: 65.23 TFLOPs/s
Flash2Sink fwd: 310.95 TFLOPs/s, bwd: 120.92 TFLOPs/s, fwd + bwd: 146.50 TFLOPs/s
Flash2SinkFused fwd: 305.13 TFLOPs/s, bwd: 296.15 TFLOPs/s, fwd + bwd: 298.66 TFLOPs/s


aoxy commented Aug 18, 2025

H20

flash-attention/benchmarks$ CUDA_VISIBLE_DEVICES=2 python benchmark_flash_attention.py

causal=False, headdim=64, batch_size=32, seqlen=512

Flash2 fwd: 80.92 TFLOPs/s, bwd: 73.09 TFLOPs/s, fwd + bwd: 75.17 TFLOPs/s
Flash2UnPacked fwd: 81.55 TFLOPs/s, bwd: 73.23 TFLOPs/s, fwd + bwd: 75.43 TFLOPs/s
Pytorch fwd: 40.54 TFLOPs/s, bwd: 62.09 TFLOPs/s, fwd + bwd: 53.90 TFLOPs/s
Flash2Sink fwd: 54.35 TFLOPs/s, bwd: 24.36 TFLOPs/s, fwd + bwd: 28.92 TFLOPs/s
Flash2SinkFused fwd: 79.94 TFLOPs/s, bwd: 73.06 TFLOPs/s, fwd + bwd: 74.90 TFLOPs/s

causal=False, headdim=64, batch_size=16, seqlen=1024

Flash2 fwd: 84.14 TFLOPs/s, bwd: 79.36 TFLOPs/s, fwd + bwd: 80.67 TFLOPs/s
Flash2UnPacked fwd: 84.19 TFLOPs/s, bwd: 79.45 TFLOPs/s, fwd + bwd: 80.75 TFLOPs/s
Pytorch fwd: 48.37 TFLOPs/s, bwd: 66.96 TFLOPs/s, fwd + bwd: 60.34 TFLOPs/s
Flash2Sink fwd: 66.81 TFLOPs/s, bwd: 29.11 TFLOPs/s, fwd + bwd: 34.71 TFLOPs/s
Flash2SinkFused fwd: 80.94 TFLOPs/s, bwd: 79.37 TFLOPs/s, fwd + bwd: 79.81 TFLOPs/s

causal=False, headdim=64, batch_size=8, seqlen=2048

Flash2 fwd: 85.02 TFLOPs/s, bwd: 82.84 TFLOPs/s, fwd + bwd: 83.45 TFLOPs/s
Flash2UnPacked fwd: 84.97 TFLOPs/s, bwd: 82.91 TFLOPs/s, fwd + bwd: 83.49 TFLOPs/s
Pytorch fwd: 44.86 TFLOPs/s, bwd: 81.33 TFLOPs/s, fwd + bwd: 66.00 TFLOPs/s
Flash2Sink fwd: 75.18 TFLOPs/s, bwd: 32.22 TFLOPs/s, fwd + bwd: 38.51 TFLOPs/s
Flash2SinkFused fwd: 80.60 TFLOPs/s, bwd: 82.42 TFLOPs/s, fwd + bwd: 81.89 TFLOPs/s

causal=False, headdim=64, batch_size=4, seqlen=4096

Flash2 fwd: 85.53 TFLOPs/s, bwd: 84.71 TFLOPs/s, fwd + bwd: 84.95 TFLOPs/s
Flash2UnPacked fwd: 85.61 TFLOPs/s, bwd: 84.74 TFLOPs/s, fwd + bwd: 84.99 TFLOPs/s
Pytorch fwd: 34.87 TFLOPs/s, bwd: 84.59 TFLOPs/s, fwd + bwd: 60.11 TFLOPs/s
Flash2Sink fwd: 80.26 TFLOPs/s, bwd: 34.04 TFLOPs/s, fwd + bwd: 40.74 TFLOPs/s
Flash2SinkFused fwd: 80.63 TFLOPs/s, bwd: 84.63 TFLOPs/s, fwd + bwd: 83.45 TFLOPs/s

causal=False, headdim=64, batch_size=2, seqlen=8192

Flash2 fwd: 85.76 TFLOPs/s, bwd: 85.69 TFLOPs/s, fwd + bwd: 85.71 TFLOPs/s
Flash2UnPacked fwd: 85.81 TFLOPs/s, bwd: 85.70 TFLOPs/s, fwd + bwd: 85.73 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2Sink fwd: 83.07 TFLOPs/s, bwd: 35.03 TFLOPs/s, fwd + bwd: 41.96 TFLOPs/s
Flash2SinkFused fwd: 81.77 TFLOPs/s, bwd: 85.70 TFLOPs/s, fwd + bwd: 84.54 TFLOPs/s

causal=False, headdim=64, batch_size=1, seqlen=16384

Flash2 fwd: 85.91 TFLOPs/s, bwd: 86.20 TFLOPs/s, fwd + bwd: 86.11 TFLOPs/s
Flash2UnPacked fwd: 85.96 TFLOPs/s, bwd: 86.20 TFLOPs/s, fwd + bwd: 86.13 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2Sink fwd: 84.68 TFLOPs/s, bwd: 35.56 TFLOPs/s, fwd + bwd: 42.63 TFLOPs/s
Flash2SinkFused fwd: 81.84 TFLOPs/s, bwd: 86.21 TFLOPs/s, fwd + bwd: 84.91 TFLOPs/s

causal=False, headdim=128, batch_size=32, seqlen=512

Flash2 fwd: 84.86 TFLOPs/s, bwd: 75.61 TFLOPs/s, fwd + bwd: 78.04 TFLOPs/s
Flash2UnPacked fwd: 85.01 TFLOPs/s, bwd: 75.75 TFLOPs/s, fwd + bwd: 78.18 TFLOPs/s
Pytorch fwd: 49.25 TFLOPs/s, bwd: 75.02 TFLOPs/s, fwd + bwd: 65.26 TFLOPs/s
Flash2Sink fwd: 56.30 TFLOPs/s, bwd: 25.41 TFLOPs/s, fwd + bwd: 30.14 TFLOPs/s
Flash2SinkFused fwd: 82.11 TFLOPs/s, bwd: 75.62 TFLOPs/s, fwd + bwd: 77.37 TFLOPs/s

causal=False, headdim=128, batch_size=16, seqlen=1024

Flash2 fwd: 86.95 TFLOPs/s, bwd: 81.30 TFLOPs/s, fwd + bwd: 82.84 TFLOPs/s
Flash2UnPacked fwd: 87.05 TFLOPs/s, bwd: 81.43 TFLOPs/s, fwd + bwd: 82.96 TFLOPs/s
Pytorch fwd: 61.90 TFLOPs/s, bwd: 86.03 TFLOPs/s, fwd + bwd: 77.41 TFLOPs/s
Flash2Sink fwd: 69.04 TFLOPs/s, bwd: 30.05 TFLOPs/s, fwd + bwd: 35.83 TFLOPs/s
Flash2SinkFused fwd: 83.81 TFLOPs/s, bwd: 81.42 TFLOPs/s, fwd + bwd: 82.09 TFLOPs/s

causal=False, headdim=128, batch_size=8, seqlen=2048

Flash2 fwd: 87.85 TFLOPs/s, bwd: 84.52 TFLOPs/s, fwd + bwd: 85.45 TFLOPs/s
Flash2UnPacked fwd: 87.93 TFLOPs/s, bwd: 84.56 TFLOPs/s, fwd + bwd: 85.49 TFLOPs/s
Pytorch fwd: 62.92 TFLOPs/s, bwd: 102.88 TFLOPs/s, fwd + bwd: 87.08 TFLOPs/s
Flash2Sink fwd: 77.69 TFLOPs/s, bwd: 33.06 TFLOPs/s, fwd + bwd: 39.55 TFLOPs/s
Flash2SinkFused fwd: 84.77 TFLOPs/s, bwd: 84.52 TFLOPs/s, fwd + bwd: 84.59 TFLOPs/s

causal=False, headdim=128, batch_size=4, seqlen=4096

Flash2 fwd: 88.39 TFLOPs/s, bwd: 86.22 TFLOPs/s, fwd + bwd: 86.83 TFLOPs/s
Flash2UnPacked fwd: 88.41 TFLOPs/s, bwd: 86.26 TFLOPs/s, fwd + bwd: 86.87 TFLOPs/s
Pytorch fwd: 53.91 TFLOPs/s, bwd: 109.15 TFLOPs/s, fwd + bwd: 84.43 TFLOPs/s
Flash2Sink fwd: 82.92 TFLOPs/s, bwd: 34.80 TFLOPs/s, fwd + bwd: 41.72 TFLOPs/s
Flash2SinkFused fwd: 85.07 TFLOPs/s, bwd: 86.13 TFLOPs/s, fwd + bwd: 85.82 TFLOPs/s

causal=False, headdim=128, batch_size=2, seqlen=8192

Flash2 fwd: 88.63 TFLOPs/s, bwd: 87.09 TFLOPs/s, fwd + bwd: 87.52 TFLOPs/s
Flash2UnPacked fwd: 88.67 TFLOPs/s, bwd: 87.10 TFLOPs/s, fwd + bwd: 87.54 TFLOPs/s
Pytorch fwd: 50.74 TFLOPs/s, bwd: 113.28 TFLOPs/s, fwd + bwd: 83.77 TFLOPs/s
Flash2Sink fwd: 85.80 TFLOPs/s, bwd: 35.74 TFLOPs/s, fwd + bwd: 42.89 TFLOPs/s
Flash2SinkFused fwd: 85.56 TFLOPs/s, bwd: 87.10 TFLOPs/s, fwd + bwd: 86.65 TFLOPs/s

causal=False, headdim=128, batch_size=1, seqlen=16384

Flash2 fwd: 88.78 TFLOPs/s, bwd: 87.53 TFLOPs/s, fwd + bwd: 87.88 TFLOPs/s
Flash2UnPacked fwd: 88.81 TFLOPs/s, bwd: 87.55 TFLOPs/s, fwd + bwd: 87.90 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2Sink fwd: 87.48 TFLOPs/s, bwd: 36.24 TFLOPs/s, fwd + bwd: 43.52 TFLOPs/s
Flash2SinkFused fwd: 85.70 TFLOPs/s, bwd: 87.54 TFLOPs/s, fwd + bwd: 87.01 TFLOPs/s

causal=True, headdim=64, batch_size=32, seqlen=512

Flash2 fwd: 61.08 TFLOPs/s, bwd: 52.77 TFLOPs/s, fwd + bwd: 54.91 TFLOPs/s
Flash2UnPacked fwd: 61.18 TFLOPs/s, bwd: 52.82 TFLOPs/s, fwd + bwd: 54.97 TFLOPs/s
Pytorch fwd: 11.25 TFLOPs/s, bwd: 31.04 TFLOPs/s, fwd + bwd: 20.66 TFLOPs/s
Flash2Sink fwd: 34.81 TFLOPs/s, bwd: 16.10 TFLOPs/s, fwd + bwd: 19.02 TFLOPs/s
Flash2SinkFused fwd: 61.03 TFLOPs/s, bwd: 52.69 TFLOPs/s, fwd + bwd: 54.83 TFLOPs/s

causal=True, headdim=64, batch_size=16, seqlen=1024

Flash2 fwd: 70.45 TFLOPs/s, bwd: 64.96 TFLOPs/s, fwd + bwd: 66.44 TFLOPs/s
Flash2UnPacked fwd: 70.58 TFLOPs/s, bwd: 65.11 TFLOPs/s, fwd + bwd: 66.58 TFLOPs/s
Pytorch fwd: 12.35 TFLOPs/s, bwd: 33.60 TFLOPs/s, fwd + bwd: 22.53 TFLOPs/s
Flash2Sink fwd: 49.16 TFLOPs/s, bwd: 22.10 TFLOPs/s, fwd + bwd: 26.23 TFLOPs/s
Flash2SinkFused fwd: 70.66 TFLOPs/s, bwd: 64.87 TFLOPs/s, fwd + bwd: 66.43 TFLOPs/s

causal=True, headdim=64, batch_size=8, seqlen=2048

Flash2 fwd: 76.25 TFLOPs/s, bwd: 73.57 TFLOPs/s, fwd + bwd: 74.32 TFLOPs/s
Flash2UnPacked fwd: 76.16 TFLOPs/s, bwd: 73.63 TFLOPs/s, fwd + bwd: 74.33 TFLOPs/s
Pytorch fwd: 11.70 TFLOPs/s, bwd: 40.85 TFLOPs/s, fwd + bwd: 23.87 TFLOPs/s
Flash2Sink fwd: 61.57 TFLOPs/s, bwd: 27.14 TFLOPs/s, fwd + bwd: 32.30 TFLOPs/s
Flash2SinkFused fwd: 76.61 TFLOPs/s, bwd: 73.59 TFLOPs/s, fwd + bwd: 74.43 TFLOPs/s

causal=True, headdim=64, batch_size=4, seqlen=4096

Flash2 fwd: 79.42 TFLOPs/s, bwd: 78.94 TFLOPs/s, fwd + bwd: 79.08 TFLOPs/s
Flash2UnPacked fwd: 79.49 TFLOPs/s, bwd: 78.98 TFLOPs/s, fwd + bwd: 79.13 TFLOPs/s
Pytorch fwd: 10.10 TFLOPs/s, bwd: 42.63 TFLOPs/s, fwd + bwd: 22.20 TFLOPs/s
Flash2Sink fwd: 70.83 TFLOPs/s, bwd: 30.70 TFLOPs/s, fwd + bwd: 36.63 TFLOPs/s
Flash2SinkFused fwd: 79.43 TFLOPs/s, bwd: 78.96 TFLOPs/s, fwd + bwd: 79.09 TFLOPs/s

causal=True, headdim=64, batch_size=2, seqlen=8192

Flash2 fwd: 80.98 TFLOPs/s, bwd: 82.14 TFLOPs/s, fwd + bwd: 81.81 TFLOPs/s
Flash2UnPacked fwd: 81.07 TFLOPs/s, bwd: 82.16 TFLOPs/s, fwd + bwd: 81.85 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2Sink fwd: 76.20 TFLOPs/s, bwd: 32.90 TFLOPs/s, fwd + bwd: 39.28 TFLOPs/s
Flash2SinkFused fwd: 81.29 TFLOPs/s, bwd: 82.14 TFLOPs/s, fwd + bwd: 81.90 TFLOPs/s

causal=True, headdim=64, batch_size=1, seqlen=16384

Flash2 fwd: 81.96 TFLOPs/s, bwd: 84.50 TFLOPs/s, fwd + bwd: 83.76 TFLOPs/s
Flash2UnPacked fwd: 81.87 TFLOPs/s, bwd: 84.51 TFLOPs/s, fwd + bwd: 83.74 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2Sink fwd: 79.47 TFLOPs/s, bwd: 34.38 TFLOPs/s, fwd + bwd: 41.04 TFLOPs/s
Flash2SinkFused fwd: 81.90 TFLOPs/s, bwd: 84.51 TFLOPs/s, fwd + bwd: 83.75 TFLOPs/s

causal=True, headdim=128, batch_size=32, seqlen=512

Flash2 fwd: 63.20 TFLOPs/s, bwd: 55.57 TFLOPs/s, fwd + bwd: 57.56 TFLOPs/s
Flash2UnPacked fwd: 63.43 TFLOPs/s, bwd: 55.72 TFLOPs/s, fwd + bwd: 57.72 TFLOPs/s
Pytorch fwd: 16.50 TFLOPs/s, bwd: 37.46 TFLOPs/s, fwd + bwd: 27.49 TFLOPs/s
Flash2Sink fwd: 36.02 TFLOPs/s, bwd: 17.04 TFLOPs/s, fwd + bwd: 20.06 TFLOPs/s
Flash2SinkFused fwd: 61.34 TFLOPs/s, bwd: 55.66 TFLOPs/s, fwd + bwd: 57.17 TFLOPs/s

causal=True, headdim=128, batch_size=16, seqlen=1024

Flash2 fwd: 72.71 TFLOPs/s, bwd: 67.55 TFLOPs/s, fwd + bwd: 68.95 TFLOPs/s
Flash2UnPacked fwd: 72.97 TFLOPs/s, bwd: 67.69 TFLOPs/s, fwd + bwd: 69.12 TFLOPs/s
Pytorch fwd: 19.12 TFLOPs/s, bwd: 43.00 TFLOPs/s, fwd + bwd: 31.69 TFLOPs/s
Flash2Sink fwd: 50.75 TFLOPs/s, bwd: 23.12 TFLOPs/s, fwd + bwd: 27.37 TFLOPs/s
Flash2SinkFused fwd: 70.66 TFLOPs/s, bwd: 67.70 TFLOPs/s, fwd + bwd: 68.52 TFLOPs/s

causal=True, headdim=128, batch_size=8, seqlen=2048

Flash2 fwd: 78.43 TFLOPs/s, bwd: 75.85 TFLOPs/s, fwd + bwd: 76.57 TFLOPs/s
Flash2UnPacked fwd: 78.87 TFLOPs/s, bwd: 75.91 TFLOPs/s, fwd + bwd: 76.74 TFLOPs/s
Pytorch fwd: 19.09 TFLOPs/s, bwd: 51.46 TFLOPs/s, fwd + bwd: 34.66 TFLOPs/s
Flash2Sink fwd: 63.74 TFLOPs/s, bwd: 28.14 TFLOPs/s, fwd + bwd: 33.48 TFLOPs/s
Flash2SinkFused fwd: 76.92 TFLOPs/s, bwd: 75.89 TFLOPs/s, fwd + bwd: 76.18 TFLOPs/s

causal=True, headdim=128, batch_size=4, seqlen=4096

Flash2 fwd: 81.66 TFLOPs/s, bwd: 81.18 TFLOPs/s, fwd + bwd: 81.32 TFLOPs/s
Flash2UnPacked fwd: 81.87 TFLOPs/s, bwd: 81.19 TFLOPs/s, fwd + bwd: 81.38 TFLOPs/s
Pytorch fwd: 17.20 TFLOPs/s, bwd: 54.63 TFLOPs/s, fwd + bwd: 33.69 TFLOPs/s
Flash2Sink fwd: 73.03 TFLOPs/s, bwd: 31.67 TFLOPs/s, fwd + bwd: 37.78 TFLOPs/s
Flash2SinkFused fwd: 80.45 TFLOPs/s, bwd: 81.21 TFLOPs/s, fwd + bwd: 80.99 TFLOPs/s

causal=True, headdim=128, batch_size=2, seqlen=8192

Flash2 fwd: 83.46 TFLOPs/s, bwd: 84.58 TFLOPs/s, fwd + bwd: 84.25 TFLOPs/s
Flash2UnPacked fwd: 83.51 TFLOPs/s, bwd: 84.61 TFLOPs/s, fwd + bwd: 84.29 TFLOPs/s
Pytorch fwd: 16.27 TFLOPs/s, bwd: 56.76 TFLOPs/s, fwd + bwd: 33.18 TFLOPs/s
Flash2Sink fwd: 78.52 TFLOPs/s, bwd: 33.94 TFLOPs/s, fwd + bwd: 40.51 TFLOPs/s
Flash2SinkFused fwd: 81.96 TFLOPs/s, bwd: 84.57 TFLOPs/s, fwd + bwd: 83.81 TFLOPs/s

causal=True, headdim=128, batch_size=1, seqlen=16384

Flash2 fwd: 84.49 TFLOPs/s, bwd: 87.69 TFLOPs/s, fwd + bwd: 86.75 TFLOPs/s
Flash2UnPacked fwd: 84.52 TFLOPs/s, bwd: 87.70 TFLOPs/s, fwd + bwd: 86.77 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s, bwd: 0.00 TFLOPs/s, fwd + bwd: 0.00 TFLOPs/s
Flash2Sink fwd: 82.31 TFLOPs/s, bwd: 35.69 TFLOPs/s, fwd + bwd: 42.58 TFLOPs/s
Flash2SinkFused fwd: 83.06 TFLOPs/s, bwd: 87.71 TFLOPs/s, fwd + bwd: 86.33 TFLOPs/s

aoxy force-pushed the feature/attention_with_sink branch from e34d3ad to c00f806 on August 22, 2025 at 03:50
aoxy force-pushed the feature/attention_with_sink branch from 47e6f28 to 9eb63cb on January 15, 2026 at 03:44

Development

Successfully merging this pull request may close these issues.

Any plans to backport additive attention sinks to flash-attn-2?
