feat: Implement Sink Attention #1
My current implementation approach is to add a new top-level interface.
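For illustration, here is a rough sketch of how that interface would be used. Only the name `flash_attn_sink_func` comes from this PR; the import path, tensor layout, and per-head sink shape below are assumptions:

```python
import torch

# Hypothetical usage sketch; shapes and the import path are assumptions,
# only the function name flash_attn_sink_func comes from this PR.
from flash_attn import flash_attn_sink_func  # assumed import path

batch, seqlen, nheads, headdim = 2, 1024, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Assumed layout: one learnable sink logit per head.
sink = torch.nn.Parameter(torch.zeros(nheads, device="cuda", dtype=torch.bfloat16))

out = flash_attn_sink_func(q, k, v, sink, causal=True)  # (batch, seqlen, nheads, headdim)
out.sum().backward()  # sink receives gradients through the fused backward pass
```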
H800:
```
flash-attention/benchmarks$ python benchmark_flash_attention.py
causal=False, headdim=64, batch_size=32, seqlen=512
Flash2 fwd: 254.08 TFLOPs/s, bwd: 185.81 TFLOPs/s, fwd + bwd: 201.26 TFLOPs/s
causal=False, headdim=64, batch_size=16, seqlen=1024
Flash2 fwd: 290.81 TFLOPs/s, bwd: 229.42 TFLOPs/s, fwd + bwd: 244.14 TFLOPs/s
causal=False, headdim=64, batch_size=8, seqlen=2048
Flash2 fwd: 306.05 TFLOPs/s, bwd: 252.78 TFLOPs/s, fwd + bwd: 266.01 TFLOPs/s
causal=False, headdim=64, batch_size=4, seqlen=4096
Flash2 fwd: 313.88 TFLOPs/s, bwd: 266.13 TFLOPs/s, fwd + bwd: 278.22 TFLOPs/s
causal=False, headdim=64, batch_size=2, seqlen=8192
Flash2 fwd: 313.67 TFLOPs/s, bwd: 275.84 TFLOPs/s, fwd + bwd: 285.68 TFLOPs/s
causal=False, headdim=64, batch_size=1, seqlen=16384
Flash2 fwd: 315.12 TFLOPs/s, bwd: 278.46 TFLOPs/s, fwd + bwd: 288.03 TFLOPs/s
causal=False, headdim=128, batch_size=32, seqlen=512
Flash2 fwd: 297.26 TFLOPs/s, bwd: 198.26 TFLOPs/s, fwd + bwd: 219.11 TFLOPs/s
causal=False, headdim=128, batch_size=16, seqlen=1024
Flash2 fwd: 337.05 TFLOPs/s, bwd: 238.64 TFLOPs/s, fwd + bwd: 260.36 TFLOPs/s
causal=False, headdim=128, batch_size=8, seqlen=2048
Flash2 fwd: 357.42 TFLOPs/s, bwd: 264.75 TFLOPs/s, fwd + bwd: 285.93 TFLOPs/s
causal=False, headdim=128, batch_size=4, seqlen=4096
Flash2 fwd: 366.47 TFLOPs/s, bwd: 279.28 TFLOPs/s, fwd + bwd: 299.65 TFLOPs/s
causal=False, headdim=128, batch_size=2, seqlen=8192
Flash2 fwd: 369.66 TFLOPs/s, bwd: 289.60 TFLOPs/s, fwd + bwd: 308.70 TFLOPs/s
causal=False, headdim=128, batch_size=1, seqlen=16384
Flash2 fwd: 370.63 TFLOPs/s, bwd: 295.78 TFLOPs/s, fwd + bwd: 313.90 TFLOPs/s
causal=True, headdim=64, batch_size=32, seqlen=512
Flash2 fwd: 175.09 TFLOPs/s, bwd: 124.75 TFLOPs/s, fwd + bwd: 135.91 TFLOPs/s
causal=True, headdim=64, batch_size=16, seqlen=1024
Flash2 fwd: 223.30 TFLOPs/s, bwd: 174.80 TFLOPs/s, fwd + bwd: 186.37 TFLOPs/s
causal=True, headdim=64, batch_size=8, seqlen=2048
Flash2 fwd: 255.18 TFLOPs/s, bwd: 217.07 TFLOPs/s, fwd + bwd: 226.75 TFLOPs/s
causal=True, headdim=64, batch_size=4, seqlen=4096
Flash2 fwd: 275.49 TFLOPs/s, bwd: 242.80 TFLOPs/s, fwd + bwd: 251.32 TFLOPs/s
causal=True, headdim=64, batch_size=2, seqlen=8192
Flash2 fwd: 286.01 TFLOPs/s, bwd: 263.67 TFLOPs/s, fwd + bwd: 269.69 TFLOPs/s
causal=True, headdim=64, batch_size=1, seqlen=16384
Flash2 fwd: 290.80 TFLOPs/s, bwd: 278.21 TFLOPs/s, fwd + bwd: 281.69 TFLOPs/s
causal=True, headdim=128, batch_size=32, seqlen=512
Flash2 fwd: 198.90 TFLOPs/s, bwd: 133.71 TFLOPs/s, fwd + bwd: 147.52 TFLOPs/s
causal=True, headdim=128, batch_size=16, seqlen=1024
Flash2 fwd: 254.38 TFLOPs/s, bwd: 184.45 TFLOPs/s, fwd + bwd: 200.18 TFLOPs/s
causal=True, headdim=128, batch_size=8, seqlen=2048
Flash2 fwd: 294.34 TFLOPs/s, bwd: 226.72 TFLOPs/s, fwd + bwd: 242.65 TFLOPs/s
causal=True, headdim=128, batch_size=4, seqlen=4096
Flash2 fwd: 317.22 TFLOPs/s, bwd: 255.42 TFLOPs/s, fwd + bwd: 270.48 TFLOPs/s
causal=True, headdim=128, batch_size=2, seqlen=8192
Flash2 fwd: 329.36 TFLOPs/s, bwd: 275.89 TFLOPs/s, fwd + bwd: 289.31 TFLOPs/s
causal=True, headdim=128, batch_size=1, seqlen=16384
Flash2 fwd: 332.55 TFLOPs/s, bwd: 296.42 TFLOPs/s, fwd + bwd: 305.92 TFLOPs/s
```
H20:
```
flash-attention/benchmarks$ CUDA_VISIBLE_DEVICES=2 python benchmark_flash_attention.py
causal=False, headdim=64, batch_size=32, seqlen=512
Flash2 fwd: 80.92 TFLOPs/s, bwd: 73.09 TFLOPs/s, fwd + bwd: 75.17 TFLOPs/s
causal=False, headdim=64, batch_size=16, seqlen=1024
Flash2 fwd: 84.14 TFLOPs/s, bwd: 79.36 TFLOPs/s, fwd + bwd: 80.67 TFLOPs/s
causal=False, headdim=64, batch_size=8, seqlen=2048
Flash2 fwd: 85.02 TFLOPs/s, bwd: 82.84 TFLOPs/s, fwd + bwd: 83.45 TFLOPs/s
causal=False, headdim=64, batch_size=4, seqlen=4096
Flash2 fwd: 85.53 TFLOPs/s, bwd: 84.71 TFLOPs/s, fwd + bwd: 84.95 TFLOPs/s
causal=False, headdim=64, batch_size=2, seqlen=8192
Flash2 fwd: 85.76 TFLOPs/s, bwd: 85.69 TFLOPs/s, fwd + bwd: 85.71 TFLOPs/s
causal=False, headdim=64, batch_size=1, seqlen=16384
Flash2 fwd: 85.91 TFLOPs/s, bwd: 86.20 TFLOPs/s, fwd + bwd: 86.11 TFLOPs/s
causal=False, headdim=128, batch_size=32, seqlen=512
Flash2 fwd: 84.86 TFLOPs/s, bwd: 75.61 TFLOPs/s, fwd + bwd: 78.04 TFLOPs/s
causal=False, headdim=128, batch_size=16, seqlen=1024
Flash2 fwd: 86.95 TFLOPs/s, bwd: 81.30 TFLOPs/s, fwd + bwd: 82.84 TFLOPs/s
causal=False, headdim=128, batch_size=8, seqlen=2048
Flash2 fwd: 87.85 TFLOPs/s, bwd: 84.52 TFLOPs/s, fwd + bwd: 85.45 TFLOPs/s
causal=False, headdim=128, batch_size=4, seqlen=4096
Flash2 fwd: 88.39 TFLOPs/s, bwd: 86.22 TFLOPs/s, fwd + bwd: 86.83 TFLOPs/s
causal=False, headdim=128, batch_size=2, seqlen=8192
Flash2 fwd: 88.63 TFLOPs/s, bwd: 87.09 TFLOPs/s, fwd + bwd: 87.52 TFLOPs/s
causal=False, headdim=128, batch_size=1, seqlen=16384
Flash2 fwd: 88.78 TFLOPs/s, bwd: 87.53 TFLOPs/s, fwd + bwd: 87.88 TFLOPs/s
causal=True, headdim=64, batch_size=32, seqlen=512
Flash2 fwd: 61.08 TFLOPs/s, bwd: 52.77 TFLOPs/s, fwd + bwd: 54.91 TFLOPs/s
causal=True, headdim=64, batch_size=16, seqlen=1024
Flash2 fwd: 70.45 TFLOPs/s, bwd: 64.96 TFLOPs/s, fwd + bwd: 66.44 TFLOPs/s
causal=True, headdim=64, batch_size=8, seqlen=2048
Flash2 fwd: 76.25 TFLOPs/s, bwd: 73.57 TFLOPs/s, fwd + bwd: 74.32 TFLOPs/s
causal=True, headdim=64, batch_size=4, seqlen=4096
Flash2 fwd: 79.42 TFLOPs/s, bwd: 78.94 TFLOPs/s, fwd + bwd: 79.08 TFLOPs/s
causal=True, headdim=64, batch_size=2, seqlen=8192
Flash2 fwd: 80.98 TFLOPs/s, bwd: 82.14 TFLOPs/s, fwd + bwd: 81.81 TFLOPs/s
causal=True, headdim=64, batch_size=1, seqlen=16384
Flash2 fwd: 81.96 TFLOPs/s, bwd: 84.50 TFLOPs/s, fwd + bwd: 83.76 TFLOPs/s
causal=True, headdim=128, batch_size=32, seqlen=512
Flash2 fwd: 63.20 TFLOPs/s, bwd: 55.57 TFLOPs/s, fwd + bwd: 57.56 TFLOPs/s
causal=True, headdim=128, batch_size=16, seqlen=1024
Flash2 fwd: 72.71 TFLOPs/s, bwd: 67.55 TFLOPs/s, fwd + bwd: 68.95 TFLOPs/s
causal=True, headdim=128, batch_size=8, seqlen=2048
Flash2 fwd: 78.43 TFLOPs/s, bwd: 75.85 TFLOPs/s, fwd + bwd: 76.57 TFLOPs/s
causal=True, headdim=128, batch_size=4, seqlen=4096
Flash2 fwd: 81.66 TFLOPs/s, bwd: 81.18 TFLOPs/s, fwd + bwd: 81.32 TFLOPs/s
causal=True, headdim=128, batch_size=2, seqlen=8192
Flash2 fwd: 83.46 TFLOPs/s, bwd: 84.58 TFLOPs/s, fwd + bwd: 84.25 TFLOPs/s
causal=True, headdim=128, batch_size=1, seqlen=16384
Flash2 fwd: 84.49 TFLOPs/s, bwd: 87.69 TFLOPs/s, fwd + bwd: 86.75 TFLOPs/s
```
Fix attention with sink in `combine_attn_seqk_parallel`.
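For intuition, the fix concerns the split-KV combine step: when the key sequence is processed in parallel splits and the partial results are merged, the sink logit must enter the final softmax normalization exactly once. A minimal PyTorch sketch of that correction, under the assumption that the sink is a per-head logit with no value vector (names and layouts here are illustrative, not the kernel's actual code):

```python
import torch

def combine_splits_with_sink(out_partial, lse_partial, sink):
    """Merge per-split partial attention outputs, folding in the sink logit.

    out_partial: (nsplits, seqlen_q, headdim), each split's output already
                 normalized by that split's local softmax denominator.
    lse_partial: (nsplits, seqlen_q), log-sum-exp of each split's logits.
    sink:        scalar sink logit for this head (assumed layout).
    """
    # Global log-sum-exp over the real key logits of all splits.
    lse = torch.logsumexp(lse_partial, dim=0)               # (seqlen_q,)
    # Reweight each split's output by its share of the global softmax mass.
    weights = torch.exp(lse_partial - lse)                   # (nsplits, seqlen_q)
    out = (weights.unsqueeze(-1) * out_partial).sum(dim=0)   # (seqlen_q, headdim)
    # The sink adds exp(sink) to the denominator but contributes no value,
    # so the combined output is scaled by 1 / (1 + exp(sink - lse)).
    return out / (1.0 + torch.exp(sink - lse)).unsqueeze(-1)
```

The key point is that the rescaling by `1 / (1 + exp(sink - lse))` must happen once, after the global log-sum-exp is known, rather than per split.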
This implementation provides a fused and highly optimized kernel for performing Sink Attention within the FlashAttention framework.
Key Changes
- The core kernels (`flash_fwd_kernel.h`, `flash_bwd_kernel.h`) have been extended to optionally accept a `sink` tensor. The softmax calculation is modified to incorporate the sink logits, ensuring the initial tokens receive consistent attention.
- New API functions (`mha_sink_fwd`, `mha_sink_bwd`) have been added to `flash_api.cpp` and exposed to Python, providing a direct interface to the new kernel logic.
- A new top-level interface, `flash_attn_sink_func`, is introduced. It is implemented as a `torch.autograd.Function` that integrates seamlessly with PyTorch's automatic differentiation.
- A naive reference implementation (`naive_attn_with_sink.py`) was added to serve as ground truth for correctness; see the sketch after this list.
- Tests (`test.py`, `test_fused.py`) were added to validate the fused kernel's forward and backward passes against the naive implementation, ensuring numerical correctness.
- The benchmark script (`benchmark_flash_attention.py`) has been updated to include `Flash2Sink` and `Flash2SinkFused` methods to measure the performance of this new feature.
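As a point of reference, a naive attention-with-sink forward pass fits in a few lines of PyTorch. The following is a sketch in the spirit of `naive_attn_with_sink.py` (the tensor layout and per-head sink shape are assumptions), not that file's verbatim contents:

```python
import math
import torch

def naive_attn_with_sink(q, k, v, sink, causal=True):
    """Reference attention with a per-head sink logit.

    q, k, v: (batch, nheads, seqlen, headdim)  -- assumed layout
    sink:    (nheads,) learnable logit that joins the softmax but has no
             value vector, so it only absorbs probability mass.
    """
    b, h, s, d = q.shape
    logits = torch.einsum("bhqd,bhkd->bhqk", q, k) / math.sqrt(d)
    if causal:
        mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), 1)
        logits = logits.masked_fill(mask, float("-inf"))
    # Append the sink logit as an extra "virtual" key column ...
    sink_col = sink.view(1, h, 1, 1).expand(b, h, s, 1)
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    # ... then drop it: only the real key columns contribute to the output.
    return torch.einsum("bhqk,bhkd->bhqd", probs[..., :s], v)
```

The fused kernel should match this reference up to the usual FlashAttention numerical tolerance, which is what `test.py` and `test_fused.py` verify.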
Motivation
The growing need for efficient handling of long contexts and streaming applications in LLMs makes Sink Attention a vital feature. This implementation allows models that rely on this mechanism to leverage the speed and memory efficiency of FlashAttention.
Fixes Dao-AILab#1802