[KDA] sm90 GVA enhance #64
base: main
Changes from all commits (19): 6862e4c, c01cf10, f7b3960, e001f72, aa9ba8a, c6886fb, 5a6964a, c516f28, c9e4e97, 987df50, ca8f431, 1b27a7f, 95086a4, 2b0c922, b538129, 5415db7, e5d24c3, 955e536, 09dbba9
First changed file: bench_kda_fused_fwd.py
@@ -25,9 +25,14 @@
 - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton
 - Performance: kernel execution time (ms) with CUDA events

-Modes:
-- Fixed-length: various (B, T) configs
-- Varlen: sequences with 2-3x length variation
+Modes (--mode, default: both):
+- fixed: Fixed-length sequences, various (B, T, H, HV) configs.
+  When HV > H the row is a GVA (Grouped Value Attention) workload;
+  GVA and MHA rows share the same sequence-length settings so they
+  can be compared side by side in a single table.
+- varlen: Variable-length sequences with 2-3x length variation, same mixing
+  of GVA (HV > H) and MHA (HV == H) rows.
+- both: Run fixed + varlen (default).

 Usage:
 python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu]
@@ -67,6 +72,9 @@
 # ============================================================
 # Constants
 # ============================================================
+# Default number of Q/K heads. Each benchmark config may additionally specify
+# HV (number of V heads); when HV > H the row runs in GVA mode (the kernel
+# sees HV expanded q/k heads, prepared internally by prepare_safe_gate_inputs).
 H, D = 64, 128
 WARMUP = 25
 N_ITERS = 100
@@ -151,23 +159,68 @@ def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lo


 # ============================================================
-# Fixed-length benchmark
+# Config normalization helpers
 # ============================================================
+def _normalize_fixed_config(cfg):
+    """Accept either (B, T) or (B, T, H_qk, HV) and return the 4-tuple form.
+
+    For the 2-tuple legacy form, defaults to H_qk=HV=H (no GVA).
+    """
+    if len(cfg) == 2:
+        B, T = cfg
+        return B, T, H, H
+    if len(cfg) == 4:
+        return cfg
+    raise ValueError(f"Fixed config must be (B, T) or (B, T, H, HV), got {cfg!r}")
+
+
+def _normalize_varlen_config(cfg):
+    """Accept (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H_qk, HV).
+
+    For the 3-tuple legacy form, defaults to H_qk=HV=H (no GVA).
+    """
+    if len(cfg) == 3:
+        seq_lens, total_len, dist = cfg
+        return seq_lens, total_len, dist, H, H
+    if len(cfg) == 5:
+        return cfg
+    raise ValueError(f"Varlen config must be (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H, HV), got {cfg!r}")
+
+
+def _gva_hint(H_qk, HV):
+    """Return a short tag marking GVA vs MHA rows for progress prints."""
+    return f"[GVA {HV // H_qk}x]" if HV > H_qk else "[MHA]"
+
+
+# ============================================================
+# Fixed-length benchmark (GVA-aware via per-config HV)
+# ============================================================
 def bench_fixed(configs):
     print("\n" + "=" * 100)
-    print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton")
+    print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)")
     print("=" * 100)
     results = []

-    for B, T in configs:
+    for cfg in configs:
+        B, T, H_qk, HV = _normalize_fixed_config(cfg)
+        print(f" {_gva_hint(H_qk, HV):>9s} B={B} T={T} H={H_qk} HV={HV}")
         set_seed(SEED)
         device = torch.device("cuda")
         torch.cuda.empty_cache()

         seq_lens = [T] * B
         cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device)

-        inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE)
+        inputs = prepare_safe_gate_inputs(
+            B,
+            T,
+            H_qk,
+            D,
+            device,
+            cu_seqlens=cu_seqlens,
+            has_init_state=HAS_INIT_STATE,
+            num_v_heads=HV,
+        )
         q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"]
         A_log, dt_bias = inputs["A_log"], inputs["dt_bias"]
         scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"]
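As a quick sanity sketch of the three helpers above (not part of the PR; it assumes the module-level default H = 64 from the Constants hunk):

    assert _normalize_fixed_config((2, 1024)) == (2, 1024, 64, 64)          # legacy 2-tuple → MHA
    assert _normalize_fixed_config((1, 4096, 16, 64)) == (1, 4096, 16, 64)  # explicit GVA row
    assert _gva_hint(16, 64) == "[GVA 4x]"  # HV is 4x the Q/K head count
    assert _gva_hint(64, 64) == "[MHA]"     # HV == H → plain MHA row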
@@ -202,6 +255,8 @@ def bench_fixed(configs):
             {
                 "B": B,
                 "T": T,
+                "H": H_qk,
+                "HV": HV,
                 "rmse": rmse,
                 "rel_max": rel_max,
                 "mean_diff": mean_diff,
@@ -218,23 +273,34 @@ def bench_fixed(configs):


 # ============================================================
-# Varlen benchmark
+# Varlen benchmark (GVA-aware via per-config HV)
 # ============================================================
 def bench_varlen(configs):
     print("\n" + "=" * 100)
-    print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton")
+    print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)")
     print("=" * 100)
     results = []

-    for seq_lens, total_len, dist in configs:
+    for cfg in configs:
+        seq_lens, total_len, dist, H_qk, HV = _normalize_varlen_config(cfg)
+        print(f" {_gva_hint(H_qk, HV):>9s} {dist} {len(seq_lens)}seqs T={total_len} H={H_qk} HV={HV}")
         set_seed(SEED)
         device = torch.device("cuda")
         torch.cuda.empty_cache()

         T = total_len
         cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device)

-        inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE)
+        inputs = prepare_safe_gate_inputs(
+            1,
+            T,
+            H_qk,
+            D,
+            device,
+            cu_seqlens=cu_seqlens,
+            has_init_state=HAS_INIT_STATE,
+            num_v_heads=HV,
+        )
         q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"]
         A_log, dt_bias = inputs["A_log"], inputs["dt_bias"]
         scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"]
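For readers new to the varlen layout, a sketch of the cu_seqlens convention used above. exclusive_cumsum itself is defined elsewhere in the repo; the leading-zero-plus-total form is inferred from num_seqs = cu_seqlens.shape[0] - 1 later in this diff, so treat this stand-in as an assumption:

    # Hypothetical stand-in for the repo's exclusive_cumsum helper.
    def exclusive_cumsum_sketch(seq_lens):
        out = [0]
        for n in seq_lens:
            out.append(out[-1] + n)
        return out

    # Three packed sequences of lengths 3, 5, 2 → offsets plus the total:
    assert exclusive_cumsum_sketch([3, 5, 2]) == [0, 3, 8, 10]
    # Sequence i occupies tokens [cu[i], cu[i+1]) of the flattened (1, T_total, ...) tensors.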
@@ -276,6 +342,8 @@ def bench_fixed(configs):
                 "dist": dist,
                 "T_total": T,
                 "n_seqs": n_seqs,
+                "H": H_qk,
+                "HV": HV,
                 "rmse": rmse,
                 "rel_max": rel_max,
                 "mean_diff": mean_diff,
@@ -295,11 +363,12 @@
 # Report
 # ============================================================
 def print_report(fixed_results, varlen_results):
-    sep = "=" * 110
+    sep = "=" * 120
     print(f"\n\n{sep}")
     print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)")
     print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton")
-    print(f" H={H} D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}")
+    print(f" D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}")
+    print(" GVA rows are those with HV > H (H, HV shown per row).")
     wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP
     ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS
     mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "")
@@ -308,35 +377,39 @@ def print_report(fixed_results, varlen_results):

     if fixed_results:
         print("\n [Fixed-Length]")
-        print(f" {'─' * 90}")
+        print(f" {'─' * 110}")
         print(
-            f" {'B':>3s} {'T':>6s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}"
-            f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
+            f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ "
+            f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ "
+            f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
         )
-        print(f" {'─' * 90}")
+        print(f" {'─' * 110}")
         for r in fixed_results:
+            gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no"
             print(
-                f" {r['B']:3d} {r['T']:6d} │ "
+                f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ "
                 f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ "
                 f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x"
             )
-        print(f" {'─' * 90}")
+        print(f" {'─' * 110}")

     if varlen_results:
         print("\n [Varlen]")
-        print(f" {'─' * 105}")
+        print(f" {'─' * 120}")
         print(
-            f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}"
-            f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
+            f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ "
+            f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ "
+            f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}"
         )
-        print(f" {'─' * 105}")
+        print(f" {'─' * 120}")
         for r in varlen_results:
+            gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no"
             print(
-                f" {r['tag']:>45s} │ "
+                f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ "
                 f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ "
                 f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x"
             )
-        print(f" {'─' * 105}")
+        print(f" {'─' * 120}")

     print(f"\n{sep}\n")
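The GVA column added to both tables is computed the same way in each loop; a minimal illustration with an assumed 16→64 head row:

    r = {"H": 16, "HV": 64}
    gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no"
    assert gva_tag == "4x"  # printed in the GVA column; MHA rows show "no"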
@@ -351,7 +424,12 @@ def main():
         type=str,
         default="both",
         choices=["fixed", "varlen", "both"],
-        help="Which benchmark mode to run (default: both)",
+        help=(
+            "Which benchmark mode to run (default: both). "
+            "GVA rows (HV > H) are mixed in alongside MHA rows (HV == H) "
+            "under the same sequence-length settings, so GVA and MHA can be "
+            "compared side by side."
+        ),
     )
     parser.add_argument(
         "--ncu",

Collaborator review comment (on lines +429 to +431): this comment can be deleted
@@ -385,25 +463,45 @@
     print(
         f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG} → using {cula_kda_fused_fwd.__module__}.{cula_kda_fused_fwd.__name__}"
     )

+    # ------------------------------------------------------------------
+    # Fixed-length configs — MHA and GVA rows share the same (B, T) grid.
+    #   (B, T)           → MHA (H_qk = HV = default H)
+    #   (B, T, H_qk, HV) → GVA when HV > H_qk (HV must be a multiple of H_qk)
+    # ------------------------------------------------------------------
     fixed_configs = [
-        # (B, T)
-        (1, 512),
+        # MHA (HV == H):
         (1, 1024),
         (1, 4096),
         (1, 8192),
         (1, 16384),
-        (2, 512),
         (2, 1024),
         (2, 4096),
         (2, 8192),
         (2, 16384),
+        # GVA (HV > H) at the same (B, T) shapes for side-by-side comparison:
+        (1, 1024, 16, 64),   # 4x
+        (1, 4096, 16, 64),   # 4x
+        (1, 8192, 16, 64),   # 4x
+        (1, 16384, 16, 64),  # 4x
+        (1, 4096, 32, 64),   # 2x
+        (1, 8192, 32, 64),   # 2x
+        (1, 4096, 8, 64),    # 8x
+        (1, 8192, 8, 64),    # 8x
+        (2, 4096, 16, 64),   # 4x
+        (2, 8192, 16, 64),   # 4x
     ]

-    varlen_configs = build_varlen_configs(
+    # Varlen configs — identical sequence-length layouts replayed with and
+    # without GVA so MHA and GVA can be compared row-by-row.
+    varlen_configs_base = build_varlen_configs(
         num_seqs_list=(10, 20),
-        total_lens=(4096, 8192, 16384),
+        total_lens=(4096, 8192),
         dists=("uniform", "random", "skewed"),
     )
+    gva_varlen_mixed = [
+        (seq_lens, T, dist, H_qk, HV)
+        for (H_qk, HV) in ((16, 64), (32, 64))
+        for (seq_lens, T, dist) in varlen_configs_base
+    ]
+    varlen_configs = list(varlen_configs_base) + gva_varlen_mixed

     fixed_res, varlen_res = [], []

Collaborator review comment (on lines +479 to +489): [comment body not captured]
Collaborator review comment (on lines +492 to +504): Same as above
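To make the mixing concrete, a toy version of the comprehension above. It assumes build_varlen_configs yields one (seq_lens, total_len, dist) tuple per combination of its arguments, i.e. 2 × 2 × 3 = 12 base rows with the arguments shown:

    # Toy stand-in for a single build_varlen_configs row (lengths are illustrative).
    base = [([2048, 1024, 1024], 4096, "uniform")]
    mixed = [
        (seq_lens, T, dist, H_qk, HV)
        for (H_qk, HV) in ((16, 64), (32, 64))
        for (seq_lens, T, dist) in base
    ]
    configs = list(base) + mixed
    assert len(configs) == 3 * len(base)  # each MHA layout gains a 4x and a 2x GVA replay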
Second changed file (benchmark input helpers; defines build_varlen_configs and prepare_safe_gate_inputs):
@@ -256,25 +256,77 @@ def build_varlen_configs(


 def prepare_safe_gate_inputs(
-    batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, has_init_state=False
+    batch_size,
+    T,
+    H,
+    D,
+    device,
+    cu_seqlens=None,
+    chunk_size=CHUNK_SIZE,
+    seed=SEED,
+    has_init_state=False,
+    num_v_heads=None,
 ):
     """Prepare inputs for safe_gate benchmarks (use_gate_in_kernel=True, safe_gate=True).

+    Args:
+        batch_size: Batch size `B`.
+        T: Per-sequence length.
+        H: Number of Q/K heads (the "group" count, must be > 0).
+        D: Head dimension.
+        device: torch device.
+        cu_seqlens: Optional cumulative sequence lengths for varlen.
+        chunk_size: Chunk size for `prepare_chunk_indices`.
+        seed: RNG seed.
+        has_init_state: Whether to allocate a non-zero initial state.
+        num_v_heads: Optional number of V heads (HV). Defaults to `H` (no GVA).
+            When `HV > H`, GVA (Grouped Value Attention) is enabled:
+            - v/g/beta are allocated with HV heads.
+            - q/k are expanded from H to HV heads via
+              `repeat_interleave(..., dim=head)`, equivalent to the einops
+              pattern `repeat(x, "... h d -> ... (h g) d")` used by
+              `fla.layers.kda`.
+            - `A_log` / `dt_bias` are sized to HV because FLA's
+              `kda_gate_chunk_cumsum` indexes them per head of `g`.
+            This expanded layout works on both cuLA backends (SM100 requires
+            q.shape[-2] == v.shape[-2]; SM90 accepts both native and expanded
+            GVA) and on FLA's `chunk_kda`.
+
     All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility.
+    The returned dict always contains `H` and `HV` keys so callers can report
+    the effective head counts uniformly.
     """
+    HV = H if num_v_heads is None else num_v_heads
+    assert H > 0, f"H must be positive, got {H}."
+    assert HV > 0, f"HV must be positive, got {HV}."
+    assert HV % H == 0, f"HV ({HV}) must be a positive multiple of H ({H})."
+
     dtype = torch.bfloat16
     scale = D ** (-0.5)

     set_seed(seed)

+    # Allocate native GVA shapes:
+    #   q, k: (B, T, H, D)
+    #   v, g: (B, T, HV, D)
+    #   beta: (B, T, HV)
+    # When HV == H this collapses to the standard non-GVA layout.
     q = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
     k = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
-    v = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
-    g = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False)
-    beta = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid().requires_grad_(False)
+    v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False)
+    g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False)
+    beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False)
+
+    # GVA expansion: bring q/k up to HV heads so all tensors share head dim.
+    group = HV // H
+    if group > 1:
+        q = q.repeat_interleave(group, dim=2).contiguous()
+        k = k.repeat_interleave(group, dim=2).contiguous()

-    A_log = torch.randn(H, dtype=torch.float, device=device).requires_grad_(False)
-    dt_bias = torch.randn(H * D, dtype=torch.float, device=device).requires_grad_(False)
+    # A_log / dt_bias must match the head count of `g` (HV), otherwise
+    # kda_gate_chunk_cumsum would index out of bounds for i_h >= H.
+    A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False)
+    dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False)

     # flatten to batch_size=1 for cu_seqlens compatibility
     if batch_size != 1:

Collaborator review comment (on the new asserts): This two assert … [remainder not captured]
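A toy check of the head expansion the docstring describes (sketch only; tiny CPU shapes, not the benchmark defaults):

    import torch

    H, HV, D = 2, 4, 3  # toy sizes → group factor g = HV // H = 2
    group = HV // H
    x = torch.arange(H * D, dtype=torch.float32).reshape(1, 1, H, D)  # (B, T, H, D)
    y = x.repeat_interleave(group, dim=2)                             # (B, T, HV, D)

    # Head order after expansion is h0, h0, h1, h1 — each Q/K head repeated
    # `group` times, matching the einops pattern "... h d -> ... (h g) d".
    assert y.shape == (1, 1, HV, D)
    assert torch.equal(y[0, 0, 0], y[0, 0, 1])  # copies of h0
    assert torch.equal(y[0, 0, 2], y[0, 0, 3])  # copies of h1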
@@ -285,7 +337,7 @@ def prepare_safe_gate_inputs(
     init_state = None
     if has_init_state:
         num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size
-        init_state = torch.randn(num_seqs, H, D, D, dtype=torch.float, device=device).requires_grad_(False)
+        init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False)

     return dict(
         q=q,
@@ -300,6 +352,8 @@ def prepare_safe_gate_inputs(
         chunk_indices=chunk_indices,
         init_state=init_state,
         lower_bound=-5.0,
+        H=H,
+        HV=HV,
     )

Collaborator review comment (on lines +355 to +356): no need to return
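A hypothetical call mirroring how bench_fixed consumes this helper after the change (shapes follow the docstring above; values are illustrative and a CUDA device is assumed):

    import torch

    # Illustrative 4x GVA request: 16 Q/K heads, 64 V heads.
    inputs = prepare_safe_gate_inputs(
        batch_size=1, T=256, H=16, D=128,
        device=torch.device("cuda"),
        num_v_heads=64,
    )
    assert inputs["H"] == 16 and inputs["HV"] == 64
    assert inputs["q"].shape[-2] == 64     # q/k expanded to HV heads
    assert inputs["v"].shape[-2] == 64     # v allocated natively with HV heads
    assert inputs["A_log"].shape[0] == 64  # sized to HV for kda_gate_chunk_cumsum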
Collaborator review comment: these normalization helpers are no longer needed as well