Make cumsum compute a cumulative sum independently for every element across the other four dimensions
The problem: a punitive slowdown as n_loops grows, because the number of cumsum calls grows linearly with it:
n_loops=4 → 24 cumsums per forward pass
n_loops=5 → 30 cumsums per forward pass
...
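Taken at face value, the two data points above imply six cumsum calls per loop. The sketch below spells out that arithmetic; the helper name `cumsums_per_forward` is hypothetical, and the constant 6 is inferred from 24/4 and 30/5, so treat it as an assumption about this particular model rather than a measured value.

```python
def cumsums_per_forward(n_loops: int, per_loop: int = 6) -> int:
    """Number of cumsum calls per forward pass.

    per_loop=6 is an assumption inferred from the two observations
    (24 calls at n_loops=4, 30 calls at n_loops=5), not a measured constant.
    """
    return per_loop * n_loops


for n in (4, 5, 8):
    print(f"n_loops={n} -> {cumsums_per_forward(n)} cumsums per forward pass")
```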
The problem
The kv tensor before the cumsum has shape:
[B, H, T, D, D] = [16, 16, 1024, 64, 64]
torch.cumsum(kv, dim=2) must compute an independent cumulative sum along T for every combination of indices in the other four dimensions (B, H, D, D). That is:
16 × 16 × 64 × 64 = 1,048,576
…separate scan operations, each only 1024 steps long.
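The count follows directly from the shape: every axis except the scanned one indexes a separate scan. A small stdlib-only sketch (the helper name `scan_geometry` is hypothetical) that reproduces the numbers above:

```python
from math import prod


def scan_geometry(shape, dim):
    """For a cumsum along `dim`: (number of independent scans, length of each scan)."""
    n_scans = prod(size for axis, size in enumerate(shape) if axis != dim)
    return n_scans, shape[dim]


B, H, T, D = 16, 16, 1024, 64
n_scans, length = scan_geometry((B, H, T, D, D), dim=2)
print(n_scans, length)   # 1048576 1024
print(24 * n_scans)      # 25165824 scans per forward pass at n_loops=4

# Moving T to the front does not change the count: a cumsum over dim 0 of a
# [T, B, H, D, D] tensor is still one independent scan per (b, h, i, j).
print(scan_geometry((T, B, H, D, D), dim=0))  # (1048576, 1024)
```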
PyTorch appears to launch one (or more) CUDA blocks per scan, so a single cumsum call has to schedule over a million tiny units of work. That is catastrophic: you burn 5–6 seconds just scheduling work that the GPU finishes in ~100 ms. You are launch-bound, not compute-bound or memory-bound.
And because dim=2 has a stride of D×D = 4096 floats (16 KB), the memory access pattern is strided and non-contiguous, which makes each tiny scan even slower.
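The 4096-element stride comes from row-major layout: the stride of an axis is the product of all sizes to its right. A stdlib-only sketch, assuming kv is contiguous and float32 (the helper name `contiguous_strides` is hypothetical; for a contiguous tensor it matches what `torch.Tensor.stride()` reports):

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a tensor of `shape`."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


shape = (16, 16, 1024, 64, 64)        # [B, H, T, D, D]
strides = contiguous_strides(shape)
print(strides)                        # (67108864, 4194304, 4096, 64, 1)

BYTES_PER_ELEM = 4                    # assumes float32
print(strides[2] * BYTES_PER_ELEM)    # 16384 bytes between consecutive T steps
```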
The proposed solution
Replace
kv_cum = torch.cumsum(kv, dim=2)
With
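# kv: [B, H, T, D, D]
# Move T to the front and materialize that layout (a full copy of kv);
# use a new name so kv itself keeps its original [B, H, T, D, D] shape
kv_t = kv.permute(2, 0, 1, 3, 4).contiguous()  # [T, B, H, D, D]
# Still B*H*D*D independent prefix sums of length T, now over a T-first layout
kv_cum = torch.cumsum(kv_t, dim=0)
kv_cum = kv_cum.permute(1, 2, 0, 3, 4)  # [B, H, T, D, D], non-contiguous view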
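Before adopting the replacement, it is worth checking that it is numerically equivalent to the original call and whether it actually helps on your hardware. The sketch below is one way to do that under stated assumptions: it uses small, hypothetical sizes (B, H, T, D = 2, 4, 128, 8) so it runs anywhere, wraps the permute-based version in a hypothetical helper `cumsum_t_permuted`, compares it against `torch.cumsum(kv, dim=2)` with `torch.allclose`, times both with `torch.utils.benchmark.Timer`, and uses `torch.profiler` to list the operators (and, on a GPU, the CUDA kernels) that one original call produces.

```python
import torch
from torch.profiler import ProfilerActivity, profile
from torch.utils import benchmark


def cumsum_t_permuted(kv: torch.Tensor) -> torch.Tensor:
    """Cumulative sum over T (dim 2) of a [B, H, T, D, D] tensor via a T-first copy."""
    kv_t = kv.permute(2, 0, 1, 3, 4).contiguous()  # [T, B, H, D, D], full copy
    kv_cum = torch.cumsum(kv_t, dim=0)              # prefix sums along T
    return kv_cum.permute(1, 2, 0, 3, 4)            # [B, H, T, D, D], non-contiguous view


device = "cuda" if torch.cuda.is_available() else "cpu"
# Small, hypothetical sizes; scale toward [16, 16, 1024, 64, 64] on a GPU.
B, H, T, D = 2, 4, 128, 8
kv = torch.randn(B, H, T, D, D, device=device)

# 1) Numerical equivalence with the original call.
ref = torch.cumsum(kv, dim=2)
out = cumsum_t_permuted(kv)
assert out.shape == ref.shape
assert torch.allclose(out, ref, rtol=1e-5, atol=1e-5)

# 2) Wall-clock comparison; Timer synchronizes CUDA around each measurement.
env = {"torch": torch, "kv": kv, "cumsum_t_permuted": cumsum_t_permuted}
for label, stmt in [("cumsum dim=2", "torch.cumsum(kv, dim=2)"),
                    ("permute + cumsum dim=0", "cumsum_t_permuted(kv)")]:
    timer = benchmark.Timer(stmt=stmt, globals=env)
    median_s = timer.blocked_autorange(min_run_time=0.2).median
    print(f"{label:>24}: {median_s * 1e6:.1f} us")

# 3) Which operators/kernels does one original cumsum call produce?
#    CUDA kernel rows (and their call counts) only appear when run on a GPU.
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)
with profile(activities=activities) as prof:
    torch.cumsum(kv, dim=2)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

At the real size, [16, 16, 1024, 64, 64] is 2^30 elements, i.e. 4 GiB in float32, and the `.contiguous()` copy needs another 4 GiB on top of the input and the cumsum output, so check memory headroom before benchmarking at full scale. The "# of Calls" column for the CUDA kernel rows is the direct way to see how many launches one cumsum call really costs.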