164 changes: 131 additions & 33 deletions benchmarks/bench_kda_fused_fwd.py
@@ -25,9 +25,14 @@
- Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton
- Performance: kernel execution time (ms) with CUDA events

Modes:
- Fixed-length: various (B, T) configs
- Varlen: sequences with 2-3x length variation
Modes (--mode, default: both):
- fixed: Fixed-length sequences, various (B, T, H, HV) configs.
When HV > H the row is a GVA (Grouped Value Attention) workload;
GVA and MHA rows share the same sequence-length settings so they
can be compared side by side in a single table.
- varlen: Variable-length sequences with 2-3x length variation, same mixing
of GVA (HV > H) and MHA (HV == H) rows.
- both: Run fixed + varlen (default).

Usage:
python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu]
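
To make the GVA/MHA distinction above concrete, a minimal shape sketch (assumptions: D=128 as used in this file; H=16, HV=64 taken from the configs below):

    import torch

    B, T, D = 1, 4096, 128
    H, HV = 16, 64                           # GVA: HV is a 4x multiple of H
    q = torch.randn(B, T, H, D)              # native q/k head count
    v = torch.randn(B, T, HV, D)             # v carries HV heads
    q = q.repeat_interleave(HV // H, dim=2)  # expand q/k to share HV heads
    assert q.shape == v.shape                # what the kernel actually sees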
@@ -67,6 +72,9 @@
# ============================================================
# Constants
# ============================================================
# Default number of Q/K heads. Each benchmark config may additionally specify
# HV (number of V heads); when HV > H the row runs in GVA mode (the kernel
# sees HV expanded q/k heads, prepared internally by prepare_safe_gate_inputs).
H, D = 64, 128
WARMUP = 25
N_ITERS = 100
@@ -151,23 +159,68 @@ def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lo


# ============================================================
# Fixed-length benchmark
# Config normalization helpers
# ============================================================
def _normalize_fixed_config(cfg):
"""Accept either (B, T) or (B, T, H_qk, HV) and return the 4-tuple form.

For the 2-tuple legacy form, defaults to H_qk=HV=H (no GVA).
"""
if len(cfg) == 2:
B, T = cfg
return B, T, H, H
if len(cfg) == 4:
return cfg
raise ValueError(f"Fixed config must be (B, T) or (B, T, H, HV), got {cfg!r}")
Comment on lines +162 to +174 (Collaborator):

these normalization helpers are no longer needed as well

def _normalize_varlen_config(cfg):
"""Accept (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H_qk, HV).

For the 3-tuple legacy form, defaults to H_qk=HV=H (no GVA).
"""
if len(cfg) == 3:
seq_lens, total_len, dist = cfg
return seq_lens, total_len, dist, H, H
if len(cfg) == 5:
return cfg
raise ValueError(f"Varlen config must be (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H, HV), got {cfg!r}")


def _gva_hint(H_qk, HV):
"""Return a short tag marking GVA vs MHA rows for progress prints."""
return f"[GVA {HV // H_qk}x]" if HV > H_qk else "[MHA]"


# ============================================================
# Fixed-length benchmark (GVA-aware via per-config HV)
# ============================================================
def bench_fixed(configs):
print("\n" + "=" * 100)
print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton")
print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)")
print("=" * 100)
results = []

for B, T in configs:
for cfg in configs:
B, T, H_qk, HV = _normalize_fixed_config(cfg)
print(f" {_gva_hint(H_qk, HV):>9s} B={B} T={T} H={H_qk} HV={HV}")
set_seed(SEED)
device = torch.device("cuda")
torch.cuda.empty_cache()

seq_lens = [T] * B
cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device)

inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE)
inputs = prepare_safe_gate_inputs(
B,
T,
H_qk,
D,
device,
cu_seqlens=cu_seqlens,
has_init_state=HAS_INIT_STATE,
num_v_heads=HV,
)
q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"]
A_log, dt_bias = inputs["A_log"], inputs["dt_bias"]
scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"]
@@ -202,6 +255,8 @@ def bench_fixed(configs):
{
"B": B,
"T": T,
"H": H_qk,
"HV": HV,
"rmse": rmse,
"rel_max": rel_max,
"mean_diff": mean_diff,
@@ -218,23 +273,34 @@


# ============================================================
# Varlen benchmark
# Varlen benchmark (GVA-aware via per-config HV)
# ============================================================
def bench_varlen(configs):
print("\n" + "=" * 100)
print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton")
print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)")
print("=" * 100)
results = []

for seq_lens, total_len, dist in configs:
for cfg in configs:
seq_lens, total_len, dist, H_qk, HV = _normalize_varlen_config(cfg)
print(f" {_gva_hint(H_qk, HV):>9s} {dist} {len(seq_lens)}seqs T={total_len} H={H_qk} HV={HV}")
set_seed(SEED)
device = torch.device("cuda")
torch.cuda.empty_cache()

T = total_len
cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device)

inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE)
inputs = prepare_safe_gate_inputs(
1,
T,
H_qk,
D,
device,
cu_seqlens=cu_seqlens,
has_init_state=HAS_INIT_STATE,
num_v_heads=HV,
)
q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"]
A_log, dt_bias = inputs["A_log"], inputs["dt_bias"]
scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"]
@@ -276,6 +342,8 @@ def bench_varlen(configs):
"dist": dist,
"T_total": T,
"n_seqs": n_seqs,
"H": H_qk,
"HV": HV,
"rmse": rmse,
"rel_max": rel_max,
"mean_diff": mean_diff,
@@ -295,11 +363,12 @@
# Report
# ============================================================
def print_report(fixed_results, varlen_results):
sep = "=" * 110
sep = "=" * 120
print(f"\n\n{sep}")
print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)")
print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton")
print(f" H={H} D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}")
print(f" D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}")
print(" GVA rows are those with HV > H (H, HV shown per row).")
wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP
ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS
mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "")
@@ -308,35 +377,39 @@

if fixed_results:
print("\n [Fixed-Length]")
print(f" {'─' * 90}")
print(f" {'─' * 110}")
print(
f" {'B':>3s} {'T':>6s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}"
f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ "
f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ "
f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
)
print(f" {'─' * 90}")
print(f" {'─' * 110}")
for r in fixed_results:
gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no"
print(
f" {r['B']:3d} {r['T']:6d} │ "
f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ "
f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ "
f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x"
)
print(f" {'─' * 90}")
print(f" {'─' * 110}")

if varlen_results:
print("\n [Varlen]")
print(f" {'─' * 105}")
print(f" {'─' * 120}")
print(
f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}"
f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ "
f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ "
f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
)
print(f" {'─' * 105}")
print(f" {'─' * 120}")
for r in varlen_results:
gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no"
print(
f" {r['tag']:>45s} │ "
f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ "
f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ "
f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x"
)
print(f" {'─' * 105}")
print(f" {'─' * 120}")

print(f"\n{sep}\n")

Expand All @@ -351,7 +424,12 @@ def main():
type=str,
default="both",
choices=["fixed", "varlen", "both"],
help="Which benchmark mode to run (default: both)",
help=(
"Which benchmark mode to run (default: both). "
"GVA rows (HV > H) are mixed in alongside MHA rows (HV == H) "
"under the same sequence-length settings, so GVA and MHA can be "
"compared side by side."
Comment on lines +429 to +431 (Collaborator):

this comment can be deleted

),
)
parser.add_argument(
"--ncu",
@@ -385,25 +463,45 @@
f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG} → using {cula_kda_fused_fwd.__module__}.{cula_kda_fused_fwd.__name__}"
)

# ------------------------------------------------------------------
# Fixed-length configs — MHA and GVA rows share the same (B, T) grid.
# (B, T) → MHA (H_qk = HV = default H)
# (B, T, H_qk, HV) → GVA when HV > H_qk (HV must be a multiple of H_qk)
# ------------------------------------------------------------------
fixed_configs = [
# (B, T)
(1, 512),
# MHA (HV == H):
(1, 1024),
(1, 4096),
(1, 8192),
(1, 16384),
(2, 512),
(2, 1024),
(2, 4096),
(2, 8192),
(2, 16384),
# GVA (HV > H) at the same (B, T) shapes for side-by-side comparison:
(1, 1024, 16, 64), # 4x
(1, 4096, 16, 64), # 4x
(1, 8192, 16, 64), # 4x
(1, 16384, 16, 64), # 4x
(1, 4096, 32, 64), # 2x
(1, 8192, 32, 64), # 2x
(1, 4096, 8, 64), # 8x
(1, 8192, 8, 64), # 8x
(2, 4096, 16, 64), # 4x
(2, 8192, 16, 64), # 4x
Comment on lines +479 to +489 (Collaborator):

The HV parameter can be specified by the user with --hv. Since we already have an HV parameter, these test settings are no longer needed; just restoring the list (leaving it unmodified) is OK.

]
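
As a sanity check on the side-by-side claim, every GVA row above reuses a (B, T) shape that also appears as an MHA row (a sketch, not part of the benchmark script):

    mha_shapes = {cfg for cfg in fixed_configs if len(cfg) == 2}
    for cfg in fixed_configs:
        if len(cfg) == 4:
            B, T, H_qk, HV = cfg
            assert (B, T) in mha_shapes, f"GVA row {cfg} lacks an MHA counterpart"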

varlen_configs = build_varlen_configs(
# Varlen configs — identical sequence-length layouts replayed with and
# without GVA so MHA and GVA can be compared row-by-row.
varlen_configs_base = build_varlen_configs(
num_seqs_list=(10, 20),
total_lens=(4096, 8192, 16384),
total_lens=(4096, 8192),
dists=("uniform", "random", "skewed"),
)
gva_varlen_mixed = [
(seq_lens, T, dist, H_qk, HV)
for (H_qk, HV) in ((16, 64), (32, 64))
for (seq_lens, T, dist) in varlen_configs_base
]
varlen_configs = list(varlen_configs_base) + gva_varlen_mixed
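
Assuming build_varlen_configs emits the full cross product, the base grid is 2 seq-counts x 2 total lengths x 3 distributions = 12 rows; the two (H_qk, HV) pairs replay all 12, so varlen_configs ends up with 12 + 24 = 36 rows.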
Comment on lines +492 to +504 (Collaborator):

Same as above


fixed_res, varlen_res = [], []

68 changes: 61 additions & 7 deletions benchmarks/utils.py
@@ -256,25 +256,77 @@ def build_varlen_configs(


def prepare_safe_gate_inputs(
batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, has_init_state=False
batch_size,
T,
H,
D,
device,
cu_seqlens=None,
chunk_size=CHUNK_SIZE,
seed=SEED,
has_init_state=False,
num_v_heads=None,
):
"""Prepare inputs for safe_gate benchmarks (use_gate_in_kernel=True, safe_gate=True).

Args:
batch_size: Batch size `B`.
T: Per-sequence length.
H: Number of Q/K heads (the "group" count, must be > 0).
D: Head dimension.
device: torch device.
cu_seqlens: Optional cumulative sequence lengths for varlen.
chunk_size: Chunk size for `prepare_chunk_indices`.
seed: RNG seed.
has_init_state: Whether to allocate a non-zero initial state.
num_v_heads: Optional number of V heads (HV). Defaults to `H` (no GVA).
When `HV > H`, GVA (Grouped Value Attention) is enabled:
- v/g/beta are allocated with HV heads.
- q/k/g/beta are expanded from H to HV heads via
`repeat_interleave(..., dim=head)`, equivalent to the einops
pattern `repeat(x, "... h d -> ... (h g) d")` used by
`fla.layers.kda`.
- `A_log` / `dt_bias` are sized to HV because FLA's
`kda_gate_chunk_cumsum` indexes them per head of `g`.
This expanded layout works on both cuLA backends (SM100 requires
q.shape[-2] == v.shape[-2]; SM90 accepts both native and expanded
GVA) and on FLA's `chunk_kda`.

All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility.
The returned dict always contains `H` and `HV` keys so callers can report
the effective head counts uniformly.
"""
HV = H if num_v_heads is None else num_v_heads
assert H > 0, f"H must be positive, got {H}."
Collaborator:

These two asserts (H > 0, HV > 0) can be deleted.

assert HV > 0, f"HV must be positive, got {HV}."
assert HV % H == 0, f"HV ({HV}) must be a positive multiple of H ({H})."

dtype = torch.bfloat16
scale = D ** (-0.5)

set_seed(seed)

# Allocate native GVA shapes:
# q, k: (B, T, H, D)
# v, g: (B, T, HV, D)
# beta: (B, T, HV)
# When HV == H this collapses to the standard non-GVA layout.
q = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
k = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
v = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
g = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
beta = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid().requires_grad_(False)
v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False)
g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False)
beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False)

# GVA expansion: bring q/k up to HV heads so all tensors share head dim.
group = HV // H
if group > 1:
q = q.repeat_interleave(group, dim=2).contiguous()
k = k.repeat_interleave(group, dim=2).contiguous()
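# Equivalence sketch (illustration only, assuming einops is installed):
# the expansion above matches the einops pattern the docstring cites, i.e.
#   from einops import repeat
#   torch.equal(q.repeat_interleave(group, dim=2),
#               repeat(q, "b t h d -> b t (h g) d", g=group))  # -> True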

A_log = torch.randn(H, dtype=torch.float, device=device).requires_grad_(False)
dt_bias = torch.randn(H * D, dtype=torch.float, device=device).requires_grad_(False)
# A_log / dt_bias must match the head count of `g` (HV), otherwise
# kda_gate_chunk_cumsum would index out of bounds for i_h >= H.
A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False)
dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False)
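# Shape sketch (illustration): with H=16, HV=64, D=128, A_log has shape
# (64,) and dt_bias has shape (8192,), both sized to the HV heads of `g`.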

# flatten to batch_size=1 for cu_seqlens compatibility
if batch_size != 1:
@@ -285,7 +337,7 @@ def prepare_safe_gate_inputs(
init_state = None
if has_init_state:
num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size
init_state = torch.randn(num_seqs, H, D, D, dtype=torch.float, device=device).requires_grad_(False)
init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False)

return dict(
q=q,
@@ -300,6 +352,8 @@
chunk_indices=chunk_indices,
init_state=init_state,
lower_bound=-5.0,
H=H,
HV=HV,
Comment on lines +355 to +356 (Collaborator):

no need to return

)

