diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index a87eed3..db5e9e6 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -25,9 +25,14 @@ - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton - Performance: kernel execution time (ms) with CUDA events -Modes: - - Fixed-length: various (B, T) configs - - Varlen: sequences with 2-3x length variation +Modes (--mode, default: both): + - fixed: Fixed-length sequences, various (B, T, H, HV) configs. + When HV > H the row is a GVA (Grouped Value Attention) workload; + GVA and MHA rows share the same sequence-length settings so they + can be compared side by side in a single table. + - varlen: Variable-length sequences with 2-3x length variation, same mixing + of GVA (HV > H) and MHA (HV == H) rows. + - both: Run fixed + varlen (default). Usage: python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu] @@ -67,6 +72,9 @@ # ============================================================ # Constants # ============================================================ +# Default number of Q/K heads. Each benchmark config may additionally specify +# HV (number of V heads); when HV > H the row runs in GVA mode (the kernel +# sees HV expanded q/k heads, prepared internally by prepare_safe_gate_inputs). H, D = 64, 128 WARMUP = 25 N_ITERS = 100 @@ -151,15 +159,51 @@ def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lo # ============================================================ -# Fixed-length benchmark +# Config normalization helpers +# ============================================================ +def _normalize_fixed_config(cfg): + """Accept either (B, T) or (B, T, H_qk, HV) and return the 4-tuple form. + + For the 2-tuple legacy form, defaults to H_qk=HV=H (no GVA). + """ + if len(cfg) == 2: + B, T = cfg + return B, T, H, H + if len(cfg) == 4: + return cfg + raise ValueError(f"Fixed config must be (B, T) or (B, T, H, HV), got {cfg!r}") + + +def _normalize_varlen_config(cfg): + """Accept (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H_qk, HV). + + For the 3-tuple legacy form, defaults to H_qk=HV=H (no GVA). 
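+
+    Example (illustrative values; the module-level default is H = 64):
+
+        >>> _normalize_varlen_config(([512, 1024], 1536, "uniform"))
+        ([512, 1024], 1536, 'uniform', 64, 64)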
+ """ + if len(cfg) == 3: + seq_lens, total_len, dist = cfg + return seq_lens, total_len, dist, H, H + if len(cfg) == 5: + return cfg + raise ValueError(f"Varlen config must be (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H, HV), got {cfg!r}") + + +def _gva_hint(H_qk, HV): + """Return a short tag marking GVA vs MHA rows for progress prints.""" + return f"[GVA {HV // H_qk}x]" if HV > H_qk else "[MHA]" + + +# ============================================================ +# Fixed-length benchmark (GVA-aware via per-config HV) # ============================================================ def bench_fixed(configs): print("\n" + "=" * 100) - print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") print("=" * 100) results = [] - for B, T in configs: + for cfg in configs: + B, T, H_qk, HV = _normalize_fixed_config(cfg) + print(f" {_gva_hint(H_qk, HV):>9s} B={B} T={T} H={H_qk} HV={HV}") set_seed(SEED) device = torch.device("cuda") torch.cuda.empty_cache() @@ -167,7 +211,16 @@ def bench_fixed(configs): seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) + inputs = prepare_safe_gate_inputs( + B, + T, + H_qk, + D, + device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + num_v_heads=HV, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] @@ -202,6 +255,8 @@ def bench_fixed(configs): { "B": B, "T": T, + "H": H_qk, + "HV": HV, "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, @@ -218,15 +273,17 @@ def bench_fixed(configs): # ============================================================ -# Varlen benchmark +# Varlen benchmark (GVA-aware via per-config HV) # ============================================================ def bench_varlen(configs): print("\n" + "=" * 100) - print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") print("=" * 100) results = [] - for seq_lens, total_len, dist in configs: + for cfg in configs: + seq_lens, total_len, dist, H_qk, HV = _normalize_varlen_config(cfg) + print(f" {_gva_hint(H_qk, HV):>9s} {dist} {len(seq_lens)}seqs T={total_len} H={H_qk} HV={HV}") set_seed(SEED) device = torch.device("cuda") torch.cuda.empty_cache() @@ -234,7 +291,16 @@ def bench_varlen(configs): T = total_len cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) + inputs = prepare_safe_gate_inputs( + 1, + T, + H_qk, + D, + device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + num_v_heads=HV, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] @@ -276,6 +342,8 @@ def bench_varlen(configs): "dist": dist, "T_total": T, "n_seqs": n_seqs, + "H": H_qk, + "HV": HV, "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, @@ -295,11 +363,12 @@ def bench_varlen(configs): # 
Report # ============================================================ def print_report(fixed_results, varlen_results): - sep = "=" * 110 + sep = "=" * 120 print(f"\n\n{sep}") print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)") print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton") - print(f" H={H} D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}") + print(f" D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}") + print(" GVA rows are those with HV > H (H, HV shown per row).") wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") @@ -308,35 +377,39 @@ def print_report(fixed_results, varlen_results): if fixed_results: print("\n [Fixed-Length]") - print(f" {'─' * 90}") + print(f" {'─' * 110}") print( - f" {'B':>3s} {'T':>6s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" - f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) - print(f" {'─' * 90}") + print(f" {'─' * 110}") for r in fixed_results: + gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" print( - f" {r['B']:3d} {r['T']:6d} │ " + f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 90}") + print(f" {'─' * 110}") if varlen_results: print("\n [Varlen]") - print(f" {'─' * 105}") + print(f" {'─' * 120}") print( - f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" - f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) - print(f" {'─' * 105}") + print(f" {'─' * 120}") for r in varlen_results: + gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" print( - f" {r['tag']:>45s} │ " + f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 105}") + print(f" {'─' * 120}") print(f"\n{sep}\n") @@ -351,7 +424,12 @@ def main(): type=str, default="both", choices=["fixed", "varlen", "both"], - help="Which benchmark mode to run (default: both)", + help=( + "Which benchmark mode to run (default: both). " + "GVA rows (HV > H) are mixed in alongside MHA rows (HV == H) " + "under the same sequence-length settings, so GVA and MHA can be " + "compared side by side." + ), ) parser.add_argument( "--ncu", @@ -385,25 +463,45 @@ def main(): f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG} → using {cula_kda_fused_fwd.__module__}.{cula_kda_fused_fwd.__name__}" ) + # ------------------------------------------------------------------ + # Fixed-length configs — MHA and GVA rows share the same (B, T) grid. 
+    #   (B, T)            → MHA (H_qk = HV = default H)
+    #   (B, T, H_qk, HV)  → GVA when HV > H_qk (HV must be a multiple of H_qk)
+    # ------------------------------------------------------------------
     fixed_configs = [
-        # (B, T)
-        (1, 512),
+        # MHA (HV == H):
         (1, 1024),
         (1, 4096),
         (1, 8192),
         (1, 16384),
-        (2, 512),
-        (2, 1024),
         (2, 4096),
         (2, 8192),
-        (2, 16384),
+        # GVA (HV > H) at the same (B, T) shapes for side-by-side comparison:
+        (1, 1024, 16, 64),   # 4x
+        (1, 4096, 16, 64),   # 4x
+        (1, 8192, 16, 64),   # 4x
+        (1, 16384, 16, 64),  # 4x
+        (1, 4096, 32, 64),   # 2x
+        (1, 8192, 32, 64),   # 2x
+        (1, 4096, 8, 64),    # 8x
+        (1, 8192, 8, 64),    # 8x
+        (2, 4096, 16, 64),   # 4x
+        (2, 8192, 16, 64),   # 4x
     ]

-    varlen_configs = build_varlen_configs(
+    # Varlen configs — identical sequence-length layouts replayed with and
+    # without GVA so MHA and GVA can be compared row-by-row.
+    varlen_configs_base = build_varlen_configs(
         num_seqs_list=(10, 20),
-        total_lens=(4096, 8192, 16384),
+        total_lens=(4096, 8192),
         dists=("uniform", "random", "skewed"),
     )
+    gva_varlen_mixed = [
+        (seq_lens, T, dist, H_qk, HV)
+        for (H_qk, HV) in ((16, 64), (32, 64))
+        for (seq_lens, T, dist) in varlen_configs_base
+    ]
+    varlen_configs = list(varlen_configs_base) + gva_varlen_mixed

     fixed_res, varlen_res = [], []
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
index 29bab04..7fa3e13 100644
--- a/benchmarks/utils.py
+++ b/benchmarks/utils.py
@@ -256,25 +256,77 @@ def build_varlen_configs(


 def prepare_safe_gate_inputs(
-    batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, has_init_state=False
+    batch_size,
+    T,
+    H,
+    D,
+    device,
+    cu_seqlens=None,
+    chunk_size=CHUNK_SIZE,
+    seed=SEED,
+    has_init_state=False,
+    num_v_heads=None,
 ):
     """Prepare inputs for safe_gate benchmarks (use_gate_in_kernel=True, safe_gate=True).

+    Args:
+        batch_size: Batch size `B`.
+        T: Per-sequence length.
+        H: Number of Q/K heads (the "group" count, must be > 0).
+        D: Head dimension.
+        device: torch device.
+        cu_seqlens: Optional cumulative sequence lengths for varlen.
+        chunk_size: Chunk size for `prepare_chunk_indices`.
+        seed: RNG seed.
+        has_init_state: Whether to allocate a non-zero initial state.
+        num_v_heads: Optional number of V heads (HV). Defaults to `H` (no GVA).
+            When `HV > H`, GVA (Grouped Value Attention) is enabled:
+            - v/g/beta are allocated with HV heads.
+            - q/k are expanded from H to HV heads via
+              `repeat_interleave(..., dim=head)`, equivalent to the einops
+              pattern `repeat(x, "... h d -> ... (h g) d")` used by
+              `fla.layers.kda`.
+            - `A_log` / `dt_bias` are sized to HV because FLA's
+              `kda_gate_chunk_cumsum` indexes them per head of `g`.
+            This expanded layout works on both cuLA backends (SM100 requires
+            q.shape[-2] == v.shape[-2]; SM90 accepts both native and expanded
+            GVA) and on FLA's `chunk_kda`.

     All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility.
+    The returned dict always contains `H` and `HV` keys so callers can report
+    the effective head counts uniformly.
     """
+    HV = H if num_v_heads is None else num_v_heads
+    assert H > 0, f"H must be positive, got {H}."
+    assert HV > 0, f"HV must be positive, got {HV}."
+    assert HV % H == 0, f"HV ({HV}) must be a positive multiple of H ({H})."
+
     dtype = torch.bfloat16
     scale = D ** (-0.5)
     set_seed(seed)

+    # Allocate native GVA shapes:
+    #   q, k:  (B, T, H, D)
+    #   v, g:  (B, T, HV, D)
+    #   beta:  (B, T, HV)
+    # When HV == H this collapses to the standard non-GVA layout.
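+    # Illustrative GVA sizing (concrete numbers for clarity): with H=16, HV=64
+    # the group factor is 4, so after the expansion below
+    #   q: (B, T, 16, D) --repeat_interleave(4, dim=2)--> (B, T, 64, D)
+    # and q.shape[-2] == v.shape[-2] == 64, as the SM100 backend requires.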
q = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) k = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - v = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - g = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - beta = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid().requires_grad_(False) + v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) + g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) + beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) + + # GVA expansion: bring q/k up to HV heads so all tensors share head dim. + group = HV // H + if group > 1: + q = q.repeat_interleave(group, dim=2).contiguous() + k = k.repeat_interleave(group, dim=2).contiguous() - A_log = torch.randn(H, dtype=torch.float, device=device).requires_grad_(False) - dt_bias = torch.randn(H * D, dtype=torch.float, device=device).requires_grad_(False) + # A_log / dt_bias must match the head count of `g` (HV), otherwise + # kda_gate_chunk_cumsum would index out of bounds for i_h >= H. + A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False) # flatten to batch_size=1 for cu_seqlens compatibility if batch_size != 1: @@ -285,7 +337,7 @@ def prepare_safe_gate_inputs( init_state = None if has_init_state: num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size - init_state = torch.randn(num_seqs, H, D, D, dtype=torch.float, device=device).requires_grad_(False) + init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False) return dict( q=q, @@ -300,6 +352,8 @@ def prepare_safe_gate_inputs( chunk_indices=chunk_indices, init_state=init_state, lower_bound=-5.0, + H=H, + HV=HV, ) diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index e8f3a54..9e016eb 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -36,21 +36,32 @@ kda_fwd_prefill( float scale, bool output_final_state, bool safe_gate) { - // Q, K, V: [packed_seq, H, D] (already packed by Python layer) + // Q, K: [packed_seq, num_qk_heads, D] + // V/O/g: [packed_seq, num_v_heads, D] (GVA: num_v_heads is a positive integer multiple of num_qk_heads) auto packed_seq = q.size(0); - auto num_heads = q.size(1); + auto num_qk_heads = q.size(1); + auto num_v_heads = v.size(1); auto head_size = q.size(2); auto num_seqs = cu_seqlens.size(0) - 1; - // KDA constraint: all head counts must be the same - TORCH_CHECK(num_heads == k.size(1), "KDA requires num_q_heads == num_k_heads, got ", num_heads, " vs ", k.size(1)); - TORCH_CHECK(num_heads == v.size(1), "KDA requires num_q_heads == num_v_heads, got ", num_heads, " vs ", v.size(1)); + // GVA contract on the C++ side. Order matters: check positivity *before* the modulo to + // avoid % 0 / division-by-zero UB in case the Python layer passed a degenerate shape. 
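+  // e.g. (illustrative) num_qk_heads=16, num_v_heads=64 passes with a group of
+  // 4 V heads per QK head; num_qk_heads=0 must trip the first check, never the modulo.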
+ TORCH_CHECK(num_qk_heads > 0, "KDA requires num_qk_heads > 0, got ", num_qk_heads); + TORCH_CHECK(num_v_heads > 0, "KDA requires num_v_heads > 0, got ", num_v_heads); + TORCH_CHECK( + num_qk_heads == k.size(1), "KDA requires num_q_heads == num_k_heads, got ", num_qk_heads, " vs ", k.size(1)); + TORCH_CHECK( + num_v_heads % num_qk_heads == 0, + "KDA GVA requires num_v_heads to be a positive multiple of num_qk_heads, got num_v_heads=", + num_v_heads, + ", num_qk_heads=", + num_qk_heads); TORCH_CHECK(head_size == v.size(2), "KDA requires Q and V head dim to match, got ", head_size, " vs ", v.size(2)); // Allocate output if not provided torch::Tensor output = output_.has_value() ? output_.value() : torch::empty( - {packed_seq, num_heads, head_size}, + {packed_seq, num_v_heads, head_size}, torch::TensorOptions().dtype(q.dtype()).device(q.device())); // output_final_state controls the API side effect. If it is false, ignore @@ -61,7 +72,7 @@ kda_fwd_prefill( output_state = output_state_.has_value() ? output_state_.value() : torch::zeros( - {num_seqs, num_heads, head_size, head_size}, + {num_seqs, num_v_heads, head_size, head_size}, torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); } @@ -92,8 +103,8 @@ kda_fwd_prefill( TORCH_CHECK(alpha.dtype() == torch::kFloat32, "alpha must be float32"); TORCH_CHECK(alpha.is_contiguous(), "alpha must be contiguous"); TORCH_CHECK( - alpha.size(0) == packed_seq && alpha.size(1) == num_heads && alpha.size(2) == head_size, - "alpha shape must be [packed_seq, num_heads, head_size]"); + alpha.size(0) == packed_seq && alpha.size(1) == num_v_heads && alpha.size(2) == head_size, + "alpha shape must be [packed_seq, num_v_heads, head_size]"); alpha_ptr = alpha.data_ptr(); } @@ -106,12 +117,17 @@ kda_fwd_prefill( beta.dtype()); TORCH_CHECK(beta.is_contiguous(), "beta must be contiguous"); TORCH_CHECK( - beta.size(0) == packed_seq && beta.size(1) == num_heads, "beta shape must be [packed_seq, num_heads]"); + beta.size(0) == packed_seq && beta.size(1) == num_v_heads, "beta shape must be [packed_seq, num_v_heads]"); } if (input_state_.has_value()) { auto& input_state = input_state_.value(); TORCH_CHECK(input_state.dtype() == torch::kFloat32, "input_state must be float32"); TORCH_CHECK(input_state.is_contiguous(), "input_state must be contiguous"); + // Defense in depth: also enforce shape on the C++ side (Python layer should already check). 
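+    // e.g. (illustrative) num_seqs=3, num_v_heads=64, head_size=128 → the only
+    // accepted input_state shape is [3, 64, 128, 128].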
+ TORCH_CHECK( + input_state.dim() == 4 && input_state.size(0) == num_seqs && input_state.size(1) == num_v_heads && + input_state.size(2) == head_size && input_state.size(3) == head_size, + "input_state shape must be [num_seqs, num_v_heads, head_size, head_size]"); input_state_ptr = input_state.data_ptr(); } @@ -142,7 +158,8 @@ kda_fwd_prefill( cu_seqlens.data_ptr(), workspace_buffer.data_ptr(), static_cast(num_seqs), - static_cast(num_heads), + static_cast(num_qk_heads), + static_cast(num_v_heads), static_cast(head_size), static_cast(packed_seq), scale, @@ -163,7 +180,8 @@ kda_fwd_prefill( cu_seqlens.data_ptr(), workspace_buffer.data_ptr(), static_cast(num_seqs), - static_cast(num_heads), + static_cast(num_qk_heads), + static_cast(num_v_heads), static_cast(head_size), static_cast(packed_seq), scale, diff --git a/csrc/kda/sm90/collective/load_tma.hpp b/csrc/kda/sm90/collective/load_tma.hpp index d7e4c8f..1d427b0 100644 --- a/csrc/kda/sm90/collective/load_tma.hpp +++ b/csrc/kda/sm90/collective/load_tma.hpp @@ -92,10 +92,11 @@ struct CollectiveLoadTma { work_desc.seq_idx, work_desc.q_head_idx(), work_desc.tok_offset); + // Q lives in the QK head space. Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( problem_size.total_seqlen, problem_size.head_size, - problem_size.num_heads)); // global view to the packed varlen sequence + problem_size.num_qk_heads)); // global view to the packed varlen sequence Tensor m_varlen = m_varlen_head(_, _, work_desc.q_head_idx()); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(work_desc.tok_offset, _0{}), @@ -103,18 +104,19 @@ struct CollectiveLoadTma { Tensor g_full = local_tile(m_offset, make_tile(BlkSeqQ, HeadSize), make_coord(_, _0{})); // (blk, d, iter_blk) return g_full; - } else if constexpr (kind == LoadKind::kAlpha) { // same as Q currently + } else if constexpr (kind == LoadKind::kAlpha) { + // Alpha (gate) is per V/O head under GVA. DPRINTF0_W( "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", to_string(kind), work_desc.seq_idx, - work_desc.q_head_idx(), + work_desc.o_head_idx(), work_desc.tok_offset); Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( problem_size.total_seqlen, problem_size.head_size, - problem_size.num_heads)); // global view to the packed varlen sequence - Tensor m_varlen = m_varlen_head(_, _, work_desc.q_head_idx()); // slice into current head_idx + problem_size.num_v_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, work_desc.o_head_idx()); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(work_desc.tok_offset, _0{}), m_varlen); // offset to start of the current sequence @@ -122,7 +124,11 @@ struct CollectiveLoadTma { local_tile(m_offset, make_tile(BlkSeqQ, HeadSize), make_coord(_, _0{})); // (blk, d, iter_blk) return g_full; } else { - auto head_idx = (kind == LoadKind::kK ? work_desc.k_head_idx() : work_desc.v_head_idx()); + // K lives in the QK head space; V lives in the V head space. + // `kind` is a static constexpr LoadKind, so the head-count selection collapses at compile time. + constexpr bool kIsK = (kind == LoadKind::kK); + auto head_idx = kIsK ? work_desc.k_head_idx() : work_desc.v_head_idx(); + auto num_kv_heads = kIsK ? 
problem_size.num_qk_heads : problem_size.num_v_heads; DPRINTF0_W( "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", to_string(kind), @@ -132,7 +138,7 @@ struct CollectiveLoadTma { Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( problem_size.head_size, problem_size.total_seqlen, - problem_size.num_heads)); // global view to the packed varlen sequence + num_kv_heads)); // global view to the packed varlen sequence Tensor m_varlen = m_varlen_head(_, _, head_idx); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(_0{}, work_desc.tok_offset), diff --git a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp index 1f78169..301dbfd 100644 --- a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp +++ b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp @@ -471,7 +471,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { Element const* ptr_V; LayoutV dV; Element* ptr_O; LayoutO dO; float const* ptr_Alpha; LayoutAlpha dAlpha; - float* ptr_output_state; // layout fixed (kdim, vdim, num_heads, num_seqs):LayoutLeft{} + float* ptr_output_state; // layout fixed (kdim, vdim, num_v_heads, num_seqs):LayoutLeft{} float const* ptr_input_state; float scale; ElementBetaGmem const* beta_ptr; GmemStrideBeta beta_stride; @@ -506,15 +506,17 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { int64_t t = problem_size.total_seqlen; int32_t d = problem_size.head_size; + // GVA: Q/K are sized by num_qk_heads; V, alpha (gate), O, beta and the recurrent state + // are sized by num_v_heads. auto params_qk = CollectiveMmaQK::to_underlying_arguments( - make_shape(s, t, d, problem_size.num_heads), + make_shape(s, t, d, problem_size.num_qk_heads), typename CollectiveMmaQK::Arguments{ args.ptr_Q, args.dQ, args.ptr_K, args.dK, // never used, dummy }, /*workspace=*/nullptr); auto params_kv_k = CollectiveMmaKV_G2S::to_underlying_arguments( - make_shape(d, d, s, problem_size.num_heads), + make_shape(d, d, s, problem_size.num_qk_heads), typename CollectiveMmaKV_G2S::Arguments{ args.ptr_V, select<1, 0, 2>(args.dV), // not used @@ -523,7 +525,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { }, /*workspace=*/nullptr); - auto alpha_shape = make_shape(s, d, problem_size.num_heads); + auto alpha_shape = make_shape(s, d, problem_size.num_v_heads); auto alpha_stride = make_stride( get<0>(args.dAlpha), // seqlen stride get<1>(args.dAlpha), // head_dim stride @@ -538,7 +540,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { size<0>(ClusterShape{})); auto params_kv_v = CollectiveMmaKV_G2S::to_underlying_arguments( - make_shape(d, d, s, problem_size.num_heads), + make_shape(d, d, s, problem_size.num_v_heads), typename CollectiveMmaKV_G2S::Arguments{ args.ptr_V, select<1, 0, 2>(args.dV), // used as G2S for V @@ -548,8 +550,8 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { /*workspace=*/nullptr); auto params_o = CollectiveStoreO::to_underlying_arguments( - make_shape(d, s, d, problem_size.num_heads), // in O1 - // make_shape(d, s, s, problem_size.num_heads), // in O2 + make_shape(d, s, d, problem_size.num_v_heads), // in O1 + // make_shape(d, s, s, problem_size.num_v_heads), // in O2 typename CollectiveStoreO::Arguments{args.ptr_O, select<1, 0, 2>(args.dO), workspace}, workspace); @@ -567,7 +569,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { // TODO: refactor all name to varname_vartype .beta_ptr = args.beta_ptr, - .beta_layout = make_layout(make_shape(s, problem_size.num_heads), args.beta_stride), + .beta_layout = make_layout(make_shape(s, problem_size.num_v_heads), 
args.beta_stride), }; } @@ -899,7 +901,8 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { auto kv_load = [&](auto& tKVrKV) INLINE_LAMBDA { DPRINTF0_WG("[%d,%d,%d,%d]>> load tKVgKV -> tKVrKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); - int num_state_heads = problem_size.num_heads; + // GVA: state is stored per V/O head. + int num_state_heads = problem_size.num_v_heads; int state_head_idx = work_desc.o_head_idx(); auto gKV = make_tensor( make_gmem_ptr(params.ptr_input_state), @@ -914,8 +917,22 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { }; auto kv_store = [&]() INLINE_LAMBDA { // tKVrKV is carried over whole mainloop + // Skip the final-state write-back entirely when the caller does not need it + // (i.e. output_final_state=False on the public API). The check is uniform + // across the launch, so the branch cost is negligible while it saves a full + // [N, num_v_heads, D, D] float32 store to GMEM. + if (params.ptr_output_state == nullptr) { + DPRINTF0_WG( + "[%d,%d,%d,%d]>> skip tKVrKV -> tKVgKV (output_final_state=false)\n", + seq_idx, + q_head_idx, + k_head_idx, + v_head_idx); + return; + } DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); - int num_state_heads = problem_size.num_heads; + // GVA: state is stored per V/O head. + int num_state_heads = problem_size.num_v_heads; int state_head_idx = work_desc.o_head_idx(); auto gKV = make_tensor( make_gmem_ptr(params.ptr_output_state), diff --git a/csrc/kda/sm90/collective/store_tma.hpp b/csrc/kda/sm90/collective/store_tma.hpp index f7b986a..0f7f7c1 100644 --- a/csrc/kda/sm90/collective/store_tma.hpp +++ b/csrc/kda/sm90/collective/store_tma.hpp @@ -195,7 +195,7 @@ struct CollectiveStoreTma { Tensor m_varlen_head = tma_store_.get_tma_tensor(make_shape( problem_size.head_size, problem_size.total_seqlen, - problem_size.num_heads)); // global view to the packed varlen sequence + problem_size.num_v_heads)); // O lives in the V/O head space under GVA Tensor m_varlen = m_varlen_head(_, _, work_desc.o_head_idx()); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(_0{}, work_desc.tok_offset), diff --git a/csrc/kda/sm90/kda_fwd_sm90.cu b/csrc/kda/sm90/kda_fwd_sm90.cu index c13b1bd..d668db9 100644 --- a/csrc/kda/sm90/kda_fwd_sm90.cu +++ b/csrc/kda/sm90/kda_fwd_sm90.cu @@ -48,7 +48,8 @@ launch_kda_fwd_prefill_kernel_gbai( int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, - int32_t num_heads, + int32_t num_qk_heads, + int32_t num_v_heads, int32_t head_size, int64_t total_seqlen, float scale, @@ -74,7 +75,8 @@ launch_kda_fwd_prefill_kernel( int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, - int32_t num_heads, + int32_t num_qk_heads, + int32_t num_v_heads, int32_t head_size, int64_t total_seqlen, float scale, @@ -98,7 +100,8 @@ launch_kda_fwd_prefill_kernel( cu_seqlens, \ workspace_buffer, \ num_seqs, \ - num_heads, \ + num_qk_heads, \ + num_v_heads, \ head_size, \ total_seqlen, \ scale, \ @@ -137,7 +140,8 @@ launch_kda_fwd_prefill_kernel( int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, - int32_t num_heads, + int32_t num_qk_heads, + int32_t num_v_heads, int32_t head_size, int64_t total_seqlen, float scale, @@ -159,7 +163,8 @@ launch_kda_fwd_prefill_kernel( int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, - int32_t num_heads, + int32_t num_qk_heads, + int32_t num_v_heads, int32_t head_size, int64_t total_seqlen, float scale, diff --git a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu 
b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu index 93fe869..309cefa 100644 --- a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu +++ b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu @@ -40,6 +40,7 @@ launch_kda_fwd_prefill_kernel_gbai CUTE_DEVICE bool @@ -42,11 +43,11 @@ struct WorkDesc { CUTE_DEVICE int32_t q_head_idx() const { - return head_idx; + return qk_head_idx; } CUTE_DEVICE int32_t k_head_idx() const { - return head_idx; + return qk_head_idx; } CUTE_DEVICE int32_t v_head_idx() const { @@ -64,11 +65,15 @@ struct WorkDesc { } }; +// Each block handles a single (seq, v_head) work item; CTAs do not cooperate. +// GVA optimization: heads_per_group is precomputed on the host and stored in +// Params, so the device side does not redo the integer division per CTA. struct IndividualTileScheduler { struct Params { dim3 grid; int32_t num_seqs; - int32_t num_heads; + int32_t num_v_heads; + int32_t heads_per_group; // = num_v_heads / num_qk_heads, precomputed on host }; bool scheduled = false; // a once flag @@ -84,19 +89,26 @@ struct IndividualTileScheduler { cutlass::KernelHardwareInfo const& hw_info, ClusterShape const& cluster_shape, TileShape const& tile_shape) { + // Compute heads_per_group once on the host so every CTA does not have to redo + // the integer division. + int32_t const heads_per_group = problem_size.num_v_heads / problem_size.num_qk_heads; dim3 grid(0, 1, 1); - grid.x = problem_size.num_seqs * problem_size.num_heads; + grid.x = problem_size.num_seqs * problem_size.num_v_heads; DPRINTF( - "to_underlying_arguments: grid:{.x:%d, .y:%d, .z:%d}, num_seqs:%d, num_heads:%d\n", + "to_underlying_arguments: grid:{.x:%d, .y:%d, .z:%d}, num_seqs:%d, num_qk_heads:%d, num_v_heads:%d, " + "heads_per_group:%d\n", grid.x, grid.y, grid.z, problem_size.num_seqs, - problem_size.num_heads); + problem_size.num_qk_heads, + problem_size.num_v_heads, + heads_per_group); return { .grid = grid, .num_seqs = problem_size.num_seqs, - .num_heads = problem_size.num_heads, + .num_v_heads = problem_size.num_v_heads, + .heads_per_group = heads_per_group, }; } @@ -108,8 +120,10 @@ struct IndividualTileScheduler { template CUTE_DEVICE WorkDesc get_next_work(Params params, ProblemSize const& problem_size) { - int32_t seq_idx = blockIdx.x / params.num_heads; - int32_t head_idx = blockIdx.x % params.num_heads; + int32_t seq_idx = blockIdx.x / params.num_v_heads; + int32_t head_idx = blockIdx.x % params.num_v_heads; + // GVA: use the host-precomputed heads_per_group to avoid device-side division. 
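+    // Worked example (illustrative sizes): num_qk_heads=16, num_v_heads=64 →
+    // heads_per_group=4, so v/o head_idx 13 maps to qk_head_idx 13 / 4 = 3.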
+ int32_t qk_head_idx = head_idx / params.heads_per_group; int32_t s = problem_size.cu_seqlens[seq_idx]; int32_t e = problem_size.cu_seqlens[seq_idx + 1]; @@ -120,8 +134,9 @@ struct IndividualTileScheduler { } else { scheduled = true; DPRINTF0_W( - "get_next_work: this_work={seq_idx:%d head_idx:%d tok_offset:%lld seq_len:%lld}\n", + "get_next_work: this_work={seq_idx:%d qk_head_idx:%d head_idx:%d tok_offset:%lld seq_len:%lld}\n", seq_idx, + qk_head_idx, head_idx, s, seq_len); @@ -129,6 +144,7 @@ struct IndividualTileScheduler { return { .seq_idx = seq_idx, + .qk_head_idx = qk_head_idx, .head_idx = head_idx, .tok_offset = s, .seq_len = seq_len, diff --git a/csrc/kda/sm90/prefill_kernel.hpp b/csrc/kda/sm90/prefill_kernel.hpp index 00cef2d..d56fafa 100644 --- a/csrc/kda/sm90/prefill_kernel.hpp +++ b/csrc/kda/sm90/prefill_kernel.hpp @@ -40,7 +40,8 @@ launch_kda_fwd_prefill_kernel( int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, - int32_t num_heads, + int32_t num_qk_heads, + int32_t num_v_heads, int32_t head_size, int64_t total_seqlen, float scale, diff --git a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh index 2fb9bda..72f13a6 100644 --- a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh +++ b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh @@ -53,7 +53,8 @@ launch_kda_fwd_prefill_kernel_gbai( int32_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, - int32_t num_heads, + int32_t num_qk_heads, + int32_t num_v_heads, int32_t head_size, int64_t total_seqlen, float scale, @@ -105,8 +106,11 @@ launch_kda_fwd_prefill_kernel_gbai( using Arguments = typename Operation::Arguments; // NOTE: LayoutQ/K/V in (seq, head_size, (b,h)) coordinate semantics + // GVA: Q/K rows are packed as [packed_seq, num_qk_heads, head_size]; + // V/O/g/beta rows are packed as [packed_seq, num_v_heads, head_size]. 
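+  // e.g. (illustrative) head_size=128, num_qk_heads=16, num_v_heads=64:
+  //   qk_tok_stride = 16 * 128 = 2048 elements, v_tok_stride = 64 * 128 = 8192.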
- int32_t tok_stride = num_heads * head_size; + int32_t qk_tok_stride = num_qk_heads * head_size; + int32_t v_tok_stride = num_v_heads * head_size; int32_t head_stride = head_size; Operation op; @@ -116,21 +120,22 @@ launch_kda_fwd_prefill_kernel_gbai( .cu_seqlens = cu_seqlens, .total_seqlen = total_seqlen, .num_seqs = num_seqs, - .num_heads = num_heads, + .num_qk_heads = num_qk_heads, + .num_v_heads = num_v_heads, .head_size = head_size, }, .mainloop = { // clang-format off - .ptr_Q = (T*)q, .dQ = {tok_stride, _1{}, head_stride}, - .ptr_K = (T*)k, .dK = {tok_stride, _1{}, head_stride}, - .ptr_V = (T*)v, .dV = {tok_stride, _1{}, head_stride}, - .ptr_O = (T*)output, .dO = {tok_stride, _1{}, head_stride}, - .ptr_Alpha = alpha, .dAlpha = {tok_stride, _1{}, head_stride}, + .ptr_Q = (T*)q, .dQ = {qk_tok_stride, _1{}, head_stride}, + .ptr_K = (T*)k, .dK = {qk_tok_stride, _1{}, head_stride}, + .ptr_V = (T*)v, .dV = {v_tok_stride, _1{}, head_stride}, + .ptr_O = (T*)output, .dO = {v_tok_stride, _1{}, head_stride}, + .ptr_Alpha = alpha, .dAlpha = {v_tok_stride, _1{}, head_stride}, .ptr_output_state = (float*)output_state, .ptr_input_state = (float*)input_state, .scale = scale, - .beta_ptr = beta, .beta_stride = {num_heads, 1}, + .beta_ptr = beta, .beta_stride = {num_v_heads, 1}, }, // clang-format on .hw_info = hw_info}; diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py index cc42827..c0399bb 100644 --- a/cula/kda/hopper_fused_fwd.py +++ b/cula/kda/hopper_fused_fwd.py @@ -49,9 +49,18 @@ def forward( chunk_indices: torch.IntTensor | None = None, ): chunk_size = 64 - assert q.shape[-2] == v.shape[-2] == k.shape[-2], "Number of heads must be the same for q, k, v." + # GVA: q/k share num_qk_heads; v/g/beta share num_v_heads. + # num_v_heads must be a positive multiple of num_qk_heads (heads_per_group = HV / H). + assert q.shape == k.shape, "q and k must have the same shape." + assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share batch and sequence dimensions." - batch_size, seq_len, num_heads, head_dim = q.shape + batch_size, seq_len, num_qk_heads, head_dim = q.shape + num_v_heads = v.shape[-2] + assert num_qk_heads > 0, f"num_qk_heads must be positive, got {num_qk_heads}." + assert num_v_heads > 0, f"num_v_heads must be positive, got {num_v_heads}." + assert num_v_heads % num_qk_heads == 0, ( + f"num_v_heads ({num_v_heads}) must be a positive multiple of num_qk_heads ({num_qk_heads})." 
+    )

        if cu_seqlens is None:
            cu_seqlens = prepare_uniform_cu_seqlens(batch_size, seq_len, q.device, torch.int32)
@@ -88,13 +97,13 @@ def forward(
            q, q_rstd = l2norm_fwd(q)
            k, k_rstd = l2norm_fwd(k)

-        # reshape to packed [T, H, K] for the C++ kernel
+        # reshape q/k to packed [T, H, K] and v/g to [T, HV, K], beta to [T, HV] for the C++ kernel
         packed_seq = batch_size * seq_len
-        q = q.reshape(packed_seq, num_heads, head_dim).contiguous()
-        k = k.reshape(packed_seq, num_heads, head_dim).contiguous()
-        v = v.reshape(packed_seq, num_heads, head_dim).contiguous()
-        g = g.reshape(packed_seq, num_heads, head_dim).contiguous()
-        beta = beta.reshape(packed_seq, num_heads).contiguous()
+        q = q.reshape(packed_seq, num_qk_heads, head_dim).contiguous()
+        k = k.reshape(packed_seq, num_qk_heads, head_dim).contiguous()
+        v = v.reshape(packed_seq, num_v_heads, head_dim).contiguous()
+        g = g.reshape(packed_seq, num_v_heads, head_dim).contiguous()
+        beta = beta.reshape(packed_seq, num_v_heads).contiguous()

         # workspace buffer for TMA Store O tensormap
         sm_count = get_device_sm_count(q.device)
@@ -159,19 +168,19 @@ def cula_kda_prefill(
         k (torch.Tensor):
             keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
-            values of shape `[B, T, H, V]`.
+            values of shape `[B, T, HV, K]`.
         g (torch.Tensor):
-            (forget) gating tensor (in log space!) of shape `[B, T, H, K]`.
+            (forget) gating tensor (in log space!) of shape `[B, T, HV, K]`.
         beta (torch.Tensor):
-            betas of shape `[B, T, H]`.
+            betas of shape `[B, T, HV]`.
         scale (Optional[float]):
             Scale factor for the KDA attention scores.
-            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
         initial_state (Optional[torch.Tensor]):
-            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            Initial state of shape `[N, HV, K, K]` for `N` input sequences.
             Default: `None`.
         output_final_state (Optional[bool]):
-            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+            Whether to output the final state of shape `[N, HV, K, K]`. Default: `False`.
         use_qk_l2norm_in_kernel (bool):
             Whether to apply L2norm to the q,k tensor internally. Default: `False`.
         use_gate_in_kernel (bool):
@@ -189,9 +198,9 @@ def cula_kda_prefill(

     Returns:
         o (torch.Tensor):
-            Outputs of shape `[B, T, H, V]`.
+            Outputs of shape `[B, T, HV, K]`.
         final_state (torch.Tensor):
-            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+            Final state of shape `[N, HV, K, K]` if `output_final_state=True` else `None`.
     """
     assert_hopper()
     assert safe_gate, "Only support safe_gate=True."
@@ -219,9 +228,32 @@ def cula_kda_prefill(
     if not (-5 <= lower_bound < 0):
         raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.")

-    assert q.shape == k.shape == g.shape, "q, k, g must have the same shape."
-    assert beta.shape == q.shape[:3], "beta must be of shape (batch size, seq len, num of head)."
-    assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)."
+    assert q.shape == k.shape, "q and k must have the same shape."
+    assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share batch and sequence dimensions."
+
+    batch_size, seq_len, num_qk_heads, head_dim = q.shape
+    num_v_heads = v.shape[-2]
+    # Order matters here: positivity *first*, modulo second, to avoid ZeroDivisionError on bad inputs.
+    assert num_qk_heads > 0, f"num_qk_heads must be positive, got {num_qk_heads}."
+ assert num_v_heads > 0, f"num_v_heads must be positive, got {num_v_heads}." + assert num_v_heads % num_qk_heads == 0, ( + f"num_v_heads ({num_v_heads}) must be a positive multiple of num_qk_heads ({num_qk_heads})." + ) + assert g.shape == (batch_size, seq_len, num_v_heads, head_dim), ( + f"g must have shape (B, T, HV, D)=({batch_size}, {seq_len}, {num_v_heads}, {head_dim}), got {tuple(g.shape)}." + ) + assert v.shape == (batch_size, seq_len, num_v_heads, head_dim), ( + f"v must have shape (B, T, HV, D)=({batch_size}, {seq_len}, {num_v_heads}, {head_dim}), got {tuple(v.shape)}." + ) + assert beta.shape == (batch_size, seq_len, num_v_heads), ( + f"beta must have shape (B, T, HV)=({batch_size}, {seq_len}, {num_v_heads}), got {tuple(beta.shape)}." + ) + if initial_state is not None: + expected_num_states = (len(cu_seqlens) - 1) if cu_seqlens is not None else batch_size + assert initial_state.shape == (expected_num_states, num_v_heads, head_dim, head_dim), ( + f"initial_state must have shape (N, HV, D, D)=" + f"({expected_num_states}, {num_v_heads}, {head_dim}, {head_dim}), got {tuple(initial_state.shape)}." + ) assert q.dtype == k.dtype == v.dtype == torch.bfloat16, "q, k, v must be in bfloat16." assert beta.dtype == torch.bfloat16 or beta.dtype == torch.float32, "beta must be in bfloat16 or float32." assert q.shape[-1] == k.shape[-1] == v.shape[-1] == 128, "Currently we only support head dim of 128 for KDA" diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index 9c32552..354f0c9 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -36,28 +36,34 @@ "B", "T", "H", + "HV", "D", "gate_logit_normalizer", "mask_p", "use_qk_l2norm_in_kernel", "use_gate_in_kernel", "safe_gate", + "use_initial_state", "dtype", ), [ pytest.param( *test, - id="B{}-T{}-H{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), + id=("B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-init{}-{}").format(*test), ) for test in [ - (1, 63, 1, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 500, 3, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 1000, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), - (3, 1024, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, True, False, True, torch.bfloat16), - (2, 1500, 4, 128, 10, 0, False, True, True, torch.bfloat16), - (4, 2048, 8, 128, 1, 0, False, True, True, torch.bfloat16), + (1, 63, 1, 1, 128, 1, 0, False, False, True, True, torch.bfloat16), + (2, 500, 3, 3, 128, 1, 0, False, False, True, True, torch.bfloat16), + (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, True, torch.bfloat16), + (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, False, False, True, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, True, False, True, True, torch.bfloat16), + (2, 1500, 4, 4, 128, 10, 0, False, True, True, True, torch.bfloat16), + (4, 2048, 8, 8, 128, 1, 0, False, True, True, True, torch.bfloat16), + (2, 512, 2, 4, 128, 1, 0, False, False, True, True, torch.bfloat16), + (2, 1024, 2, 8, 128, 1, 0, False, True, True, True, torch.bfloat16), + (1, 64, 1, 2, 128, 1, 0, False, False, True, True, torch.bfloat16), + (1, 65, 1, 4, 128, 1, 0, False, False, True, False, torch.bfloat16), ] ], ) @@ -65,12 +71,14 @@ def test_safe_gate_chunk( B: int, T: int, H: int, + HV: int, D: int, gate_logit_normalizer: float, mask_p: float, use_qk_l2norm_in_kernel: bool, use_gate_in_kernel: 
bool, safe_gate: bool, + use_initial_state: bool, dtype: torch.dtype, beta_dtype: torch.dtype, ): @@ -81,11 +89,11 @@ def test_safe_gate_chunk( torch.manual_seed(42) q = torch.rand(B, T, H, D, dtype=dtype) k = torch.rand(B, T, H, D, dtype=dtype) - v = torch.rand(B, T, H, D, dtype=dtype) - g = torch.randn(B, T, H, D, dtype=torch.float if not use_gate_in_kernel else dtype) + v = torch.rand(B, T, HV, D, dtype=dtype) + g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype) if use_gate_in_kernel: - A_log = torch.randn(H, dtype=torch.float) - dt_bias = torch.randn(H * D, dtype=torch.float) + A_log = torch.randn(HV, dtype=torch.float) + dt_bias = torch.randn(HV * D, dtype=torch.float) else: g = F.logsigmoid(g) / gate_logit_normalizer g = g * (torch.rand_like(g) > mask_p) @@ -98,33 +106,39 @@ def test_safe_gate_chunk( lower_bound = None naive_kda_gate_fn = naive_kda_gate - beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn(B, H, D, D, dtype=torch.float32) + beta = torch.randn(B, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn(B, HV, D, D, dtype=torch.float32) # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance h0_vk = h0.transpose(-1, -2).contiguous() if use_gate_in_kernel: A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(False), (A_log, dt_bias)) q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + initial_state = h0.clone() if use_initial_state else None + initial_state_vk = h0_vk.clone() if use_initial_state else None + + heads_per_group = HV // H + q_ref = q.repeat_interleave(heads_per_group, dim=2) + k_ref = k.repeat_interleave(heads_per_group, dim=2) ref, ref_ht = naive_recurrent_kda( - q=F.normalize(q.clone(), p=2, dim=-1), - k=F.normalize(k.clone(), p=2, dim=-1), + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=F.normalize(k_ref.clone(), p=2, dim=-1), v=v.clone(), g=(naive_kda_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), beta=beta.clone(), - initial_state=h0.clone(), + initial_state=initial_state, output_final_state=True, ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), - k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), + k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), A_log=(A_log.clone() if use_gate_in_kernel else None), dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=h0.clone(), + initial_state=initial_state, output_final_state=True, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_gate_in_kernel=use_gate_in_kernel, @@ -133,14 +147,14 @@ def test_safe_gate_chunk( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), - k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), + k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), A_log=(A_log.clone() if use_gate_in_kernel else None), dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - 
initial_state=h0_vk.clone(), + initial_state=initial_state_vk, output_final_state=True, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_gate_in_kernel=use_gate_in_kernel, @@ -157,7 +171,7 @@ def test_safe_gate_chunk( beta=beta.clone(), A_log=(A_log.clone() if use_gate_in_kernel else None), dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=h0_vk.clone(), + initial_state=initial_state_vk, output_final_state=True, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_gate_in_kernel=use_gate_in_kernel, @@ -166,10 +180,10 @@ def test_safe_gate_chunk( ) assert_close("o", ref, tri, 0.005) - assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) assert_close("o", ref_fla, tri, 0.005) - assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) assert_close("o", ref_fla_trans, tri, 0.005) + assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) + assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005) @@ -222,58 +236,75 @@ def test_safe_gate_chunk_no_final_state(): @pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize( - ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), + ("H", "HV", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate", "use_initial_state"), [ - pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}".format(*test)) + pytest.param( + *test, + id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}-init{}".format(*test), + ) for test in [ - (4, 128, 0.1, [0, 15], torch.bfloat16, True), - (4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True), - (4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), + (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True, True), + (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True, True), + (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True, True), + (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True, True), + (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True, True), + (2, 4, 128, 0, [0, 63, 130], torch.bfloat16, True, True), + (1, 2, 128, 0, [0, 1], torch.bfloat16, True, True), + (1, 2, 128, 0, [0, 63, 64, 65], torch.bfloat16, True, True), + (2, 4, 128, 0, [0, 17, 64, 65, 130], torch.bfloat16, True, False), # ======Varlen test with simulated trace======= ( + 32, 32, 128, 0, [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], torch.bfloat16, True, + True, ), ( + 32, 32, 128, 0, [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], torch.bfloat16, True, + True, ), ( + 32, 32, 128, 0, [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], torch.bfloat16, True, + True, ), ( + 32, 32, 128, 0, [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], torch.bfloat16, True, + True, ), ] ], ) def test_safe_gate_chunk_varlen( H: int, + HV: int, D: int, mask_p: float, cu_seqlens: list[int], dtype: torch.dtype, safe_gate: bool, + use_initial_state: bool, beta_dtype: torch.dtype, ): cula_kda_fused_fwd = get_kda_fused_fwd(device) @@ -286,19 +317,24 @@ def test_safe_gate_chunk_varlen( q = torch.randn((1, T, H, D), dtype=dtype) k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, 
dim=-1).to(dtype) - v = torch.randn((1, T, H, D), dtype=dtype) - g = F.logsigmoid(torch.randn(1, T, H, D, dtype=torch.float)) + v = torch.randn((1, T, HV, D), dtype=dtype) + g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)) mask = torch.rand_like(g) > mask_p g = g * mask + (~mask) * (-1000) if safe_gate: g = g.clamp(-5, 0) - beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn((N, H, D, D), dtype=torch.float32) + beta = torch.randn(1, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn((N, HV, D, D), dtype=torch.float32) # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance h0_vk = h0.transpose(-1, -2).contiguous() q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + initial_state = h0.clone() if use_initial_state else None + initial_state_vk = h0_vk.clone() if use_initial_state else None + heads_per_group = HV // H + q_ref = q.repeat_interleave(heads_per_group, dim=2) + k_ref = k.repeat_interleave(heads_per_group, dim=2) tri, tri_ht = cula_kda_fused_fwd( q=F.normalize(q.clone(), p=2, dim=-1), @@ -306,7 +342,7 @@ def test_safe_gate_chunk_varlen( v=v.clone(), g=g.clone(), beta=beta.clone(), - initial_state=h0_vk.clone(), + initial_state=initial_state_vk, output_final_state=True, cu_seqlens=cu_seqlens, cu_seqlens_cpu=cu_seqlens_cpu, @@ -315,12 +351,12 @@ def test_safe_gate_chunk_varlen( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1), - k=k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), - initial_state=h0.clone(), + initial_state=initial_state, output_final_state=True, cu_seqlens=cu_seqlens, cu_seqlens_cpu=cu_seqlens_cpu, @@ -329,12 +365,12 @@ def test_safe_gate_chunk_varlen( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1), - k=k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), - initial_state=h0_vk.clone(), + initial_state=initial_state_vk, output_final_state=True, cu_seqlens=cu_seqlens, cu_seqlens_cpu=cu_seqlens_cpu, @@ -347,12 +383,12 @@ def test_safe_gate_chunk_varlen( ref_ht = [] for i in range(N): ref_i, ref_ht_i = naive_recurrent_kda( - q=F.normalize(q[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), - k=k[:, cu_seqlens[i] : cu_seqlens[i + 1]], + q=F.normalize(q_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), + k=k_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]], v=v[:, cu_seqlens[i] : cu_seqlens[i + 1]], beta=beta[:, cu_seqlens[i] : cu_seqlens[i + 1]], g=g[:, cu_seqlens[i] : cu_seqlens[i + 1]], - initial_state=h0[i], + initial_state=h0[i] if use_initial_state else None, output_final_state=True, ) ref.append(ref_i) @@ -361,8 +397,8 @@ def test_safe_gate_chunk_varlen( ref_ht = torch.cat(ref_ht, 0) assert_close("o", ref, tri, 0.005) - assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) assert_close("o", ref_fla, tri, 0.005) - assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) assert_close("o", ref_fla_trans, tri, 0.005) + assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) + assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005)