
from-branch-43-make-cumsum-compute-ndependently-for-every-element-across-dimensions #42

Description

@david-thrower

Make cumsum compute a cumulative sum independently for every element across the other four dimensions.

The problem: a punitive slowdown that grows linearly with n_loops, because every loop adds six more cumsum calls to the forward pass:

n_loops=4 → 24 cumsums per forward pass
n_loops=5 → 30 cumsums per forward pass

...

The problem

The kv tensor entering the cumsum has shape:

```python
[B, H, T, D, D] = [16, 16, 1024, 64, 64]
```

torch.cumsum(kv, dim=2) must compute a cumulative sum independently for every index combination of the other four dimensions. That is 16 × 16 × 64 × 64 = 1,048,576 separate scan operations, each only 1,024 steps long.
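
The scan count follows from multiplying every dimension except the scanned one. A small pure-Python sketch of that arithmetic, using the shape above:

```python
from math import prod

# Shape of kv from this issue: [B, H, T, D, D]
shape = (16, 16, 1024, 64, 64)
scan_dim = 2  # T, the dimension torch.cumsum(kv, dim=2) runs along

# Every distinct index over the *other* dimensions defines one independent scan.
num_scans = prod(s for i, s in enumerate(shape) if i != scan_dim)
scan_length = shape[scan_dim]

print(f"{num_scans:,} scans of length {scan_length}")
```

Multiplying the two numbers back together recovers the total element count, which is a quick consistency check on the shape.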

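PyTorch schedules one (or more) CUDA thread blocks per scan. Those blocks are issued by the cumsum's kernel launch(es), not launched one at a time, so the cost is not per-block kernel-launch latency; it is the overhead of scheduling and setting up over a million tiny units of work per cumsum call. That overhead, not the arithmetic, dominates: roughly 5–6 seconds are spent on work the GPU could finish in ~100 ms. The call is launch/scheduling-bound, not compute-bound or memory-bound.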
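
One way to sanity-check the "not memory-bound" claim is a bandwidth floor: a cumsum has to read every input element and write every output element at least once. The sketch below assumes float32 and an illustrative 1 TB/s of effective memory bandwidth; both are placeholders to replace with your own dtype and GPU figure, not measurements from this issue.

```python
from math import prod

shape = (16, 16, 1024, 64, 64)   # [B, H, T, D, D]
bytes_per_elem = 4               # ASSUMPTION: float32; use 2 for fp16/bf16

elements = prod(shape)
# Read the whole input once and write the whole output once.
bytes_moved = 2 * elements * bytes_per_elem

# ASSUMPTION: ~1 TB/s effective bandwidth; substitute your GPU's figure.
bandwidth_bytes_per_s = 1e12
min_time_ms = bytes_moved / bandwidth_bytes_per_s * 1e3

print(f"{elements:,} elements, {bytes_moved / 2**30:.0f} GiB moved, "
      f">= {min_time_ms:.1f} ms per cumsum at the assumed bandwidth")
```

If a measured cumsum takes orders of magnitude longer than this floor, raw bandwidth is not what limits it; overhead or the access pattern is.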

And because dim=2 has a stride of D × D = 4096 elements (16 KB in float32), consecutive steps of each scan sit 16 KB apart in memory. This strided, non-contiguous access pattern makes each tiny scan even slower.
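
The stride figure can be reproduced without allocating the 4 GiB tensor. The helper name contiguous_strides below is hypothetical; it computes row-major element strides the way torch.Tensor.stride() reports them for a freshly allocated contiguous tensor, and float32 is assumed for the byte count. It also prints the strides of the permuted [T, B, H, D, D] layout for comparison.

```python
def contiguous_strides(shape):
    """Element strides of a contiguous (row-major) tensor of the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):  # innermost dimension has stride 1
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

bytes_per_elem = 4  # ASSUMPTION: float32

original = (16, 16, 1024, 64, 64)   # [B, H, T, D, D]
permuted = (1024, 16, 16, 64, 64)   # [T, B, H, D, D] after permute(...).contiguous()

strides = contiguous_strides(original)
permuted_strides = contiguous_strides(permuted)
t_stride = strides[2]  # elements between kv[b, h, t] and kv[b, h, t + 1]

print("original:", strides, f"T step = {t_stride * bytes_per_elem:,} bytes")
print("permuted:", permuted_strides)
```

In the permuted layout T is still the largest-stride dimension. What changes is that every other index is contiguous beneath it, so the data forms one [T, 1,048,576] matrix scanned down its columns rather than 256 separate [T, 4096] blocks.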

The proposed solution

Replace

```python
kv_cum = torch.cumsum(kv, dim=2)
```

With

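```python
import torch

# kv: [B, H, T, D, D]
# Move T to the front and make the layout contiguous, so the cumsum scans
# down dim 0 with every other index laid out contiguously beneath it
kv_t = kv.permute(2, 0, 1, 3, 4).contiguous()   # [T, B, H, D, D] (copies kv)
kv_cum = torch.cumsum(kv_t, dim=0)               # length-T scans along dim 0
kv_cum = kv_cum.permute(1, 2, 0, 3, 4)           # [B, H, T, D, D], non-contiguous view
```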
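
Before benchmarking, it is worth confirming that the time-major rewrite returns exactly what torch.cumsum(kv, dim=2) returns. The sketch below wraps the idea in a hypothetical helper, cumsum_time_major, which collapses the non-T dimensions into one axis with reshape(T, -1), and compares it against the original call on a small random tensor (shape chosen arbitrarily so it runs on CPU).

```python
import torch

def cumsum_time_major(kv: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cumsum over T (dim=2) of a [B, H, T, D, D] tensor,
    computed in a [T, B*H*D*D] layout."""
    B, H, T, D1, D2 = kv.shape
    kv_t = kv.permute(2, 0, 1, 3, 4).reshape(T, -1)  # copies into [T, B*H*D*D]
    kv_cum = torch.cumsum(kv_t, dim=0)                # scan down each column
    # Restore [B, H, T, D, D]; the result is a non-contiguous view.
    return kv_cum.reshape(T, B, H, D1, D2).permute(1, 2, 0, 3, 4)

torch.manual_seed(0)
kv = torch.randn(2, 3, 17, 4, 4, dtype=torch.float64)  # small stand-in for kv
reference = torch.cumsum(kv, dim=2)
candidate = cumsum_time_major(kv)

print(candidate.shape, torch.allclose(candidate, reference))
```

On the real [16, 16, 1024, 64, 64] tensor the permute-and-copy step materializes a second ~4 GiB float32 tensor and itself reads and writes the whole tensor, so the end-to-end gain should be confirmed by timing both versions on the GPU (calling torch.cuda.synchronize() before reading the clock). Call .contiguous() on the result if downstream code needs a contiguous tensor.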
