From 6862e4c17251ca67611975f065876f7a28292600 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Wed, 6 May 2026 11:15:07 +0800 Subject: [PATCH 01/17] sm90 --- csrc/api/kda_sm90.cu | 47 +++-- csrc/kda/sm90/collective/load_tma.hpp | 20 +- csrc/kda/sm90/collective/mainloop_kda_fwd.hpp | 24 ++- csrc/kda/sm90/collective/store_tma.hpp | 2 +- csrc/kda/sm90/kda_fwd_sm90.cu | 15 +- csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu | 4 + csrc/kda/sm90/kernel/kernel_kda_fwd.hpp | 5 +- csrc/kda/sm90/kernel/tile_scheduler.hpp | 49 +++-- csrc/kda/sm90/prefill_kernel.hpp | 3 +- csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh | 23 +- cula/kda/hopper_fused_fwd.py | 94 ++++++--- cula/utils.py | 1 + tests/test_kda_fused_fwd.py | 199 ++++++++++++------ 13 files changed, 338 insertions(+), 148 deletions(-) diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index 7acd068..28c1e49 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -35,28 +35,39 @@ kda_fwd_prefill( torch::Tensor workspace_buffer, float scale, bool safe_gate) { - // Q, K, V: [packed_seq, H, D] (already packed by Python layer) + // Q, K: [packed_seq, num_qk_heads, D] + // V/O/g: [packed_seq, num_v_heads, D] (GVA: num_v_heads is a positive integer multiple of num_qk_heads) auto packed_seq = q.size(0); - auto num_heads = q.size(1); + auto num_qk_heads = q.size(1); + auto num_v_heads = v.size(1); auto head_size = q.size(2); auto num_seqs = cu_seqlens.size(0) - 1; - // KDA constraint: all head counts must be the same - TORCH_CHECK(num_heads == k.size(1), "KDA requires num_q_heads == num_k_heads, got ", num_heads, " vs ", k.size(1)); - TORCH_CHECK(num_heads == v.size(1), "KDA requires num_q_heads == num_v_heads, got ", num_heads, " vs ", v.size(1)); + // GVA contract on the C++ side. Order matters: check positivity *before* the modulo to + // avoid % 0 / division-by-zero UB in case the Python layer passed a degenerate shape. + TORCH_CHECK(num_qk_heads > 0, "KDA requires num_qk_heads > 0, got ", num_qk_heads); + TORCH_CHECK(num_v_heads > 0, "KDA requires num_v_heads > 0, got ", num_v_heads); + TORCH_CHECK( + num_qk_heads == k.size(1), "KDA requires num_q_heads == num_k_heads, got ", num_qk_heads, " vs ", k.size(1)); + TORCH_CHECK( + num_v_heads % num_qk_heads == 0, + "KDA GVA requires num_v_heads to be a positive multiple of num_qk_heads, got num_v_heads=", + num_v_heads, + ", num_qk_heads=", + num_qk_heads); TORCH_CHECK(head_size == v.size(2), "KDA requires Q and V head dim to match, got ", head_size, " vs ", v.size(2)); - // Allocate output if not provided + // Allocate output if not provided. Output is sized by V/O heads. torch::Tensor output = output_.has_value() ? output_.value() : torch::empty( - {packed_seq, num_heads, head_size}, + {packed_seq, num_v_heads, head_size}, torch::TensorOptions().dtype(q.dtype()).device(q.device())); - // Allocate output state if not provided + // Allocate output state if not provided. State is per V-head. torch::Tensor output_state = output_state_.has_value() ? 
output_state_.value() : torch::zeros( - {num_seqs, num_heads, head_size, head_size}, + {num_seqs, num_v_heads, head_size, head_size}, torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); // Validate dtypes @@ -83,8 +94,8 @@ kda_fwd_prefill( TORCH_CHECK(alpha.dtype() == torch::kFloat32, "alpha must be float32"); TORCH_CHECK(alpha.is_contiguous(), "alpha must be contiguous"); TORCH_CHECK( - alpha.size(0) == packed_seq && alpha.size(1) == num_heads && alpha.size(2) == head_size, - "alpha shape must be [packed_seq, num_heads, head_size]"); + alpha.size(0) == packed_seq && alpha.size(1) == num_v_heads && alpha.size(2) == head_size, + "alpha shape must be [packed_seq, num_v_heads, head_size]"); alpha_ptr = alpha.data_ptr(); } if (beta_.has_value()) { @@ -95,12 +106,18 @@ kda_fwd_prefill( beta.dtype()); TORCH_CHECK(beta.is_contiguous(), "beta must be contiguous"); TORCH_CHECK( - beta.size(0) == packed_seq && beta.size(1) == num_heads, "beta shape must be [packed_seq, num_heads]"); + beta.size(0) == packed_seq && beta.size(1) == num_v_heads, + "beta shape must be [packed_seq, num_v_heads]"); } if (input_state_.has_value()) { auto& input_state = input_state_.value(); TORCH_CHECK(input_state.dtype() == torch::kFloat32, "input_state must be float32"); TORCH_CHECK(input_state.is_contiguous(), "input_state must be contiguous"); + // Defense in depth: also enforce shape on the C++ side (Python layer should already check). + TORCH_CHECK( + input_state.dim() == 4 && input_state.size(0) == num_seqs && input_state.size(1) == num_v_heads && + input_state.size(2) == head_size && input_state.size(3) == head_size, + "input_state shape must be [num_seqs, num_v_heads, head_size, head_size]"); input_state_ptr = input_state.data_ptr(); } @@ -131,7 +148,8 @@ kda_fwd_prefill( cu_seqlens.data_ptr(), workspace_buffer.data_ptr(), static_cast(num_seqs), - static_cast(num_heads), + static_cast(num_qk_heads), + static_cast(num_v_heads), static_cast(head_size), static_cast(packed_seq), scale, @@ -152,7 +170,8 @@ kda_fwd_prefill( cu_seqlens.data_ptr(), workspace_buffer.data_ptr(), static_cast(num_seqs), - static_cast(num_heads), + static_cast(num_qk_heads), + static_cast(num_v_heads), static_cast(head_size), static_cast(packed_seq), scale, diff --git a/csrc/kda/sm90/collective/load_tma.hpp b/csrc/kda/sm90/collective/load_tma.hpp index d7e4c8f..1d427b0 100644 --- a/csrc/kda/sm90/collective/load_tma.hpp +++ b/csrc/kda/sm90/collective/load_tma.hpp @@ -92,10 +92,11 @@ struct CollectiveLoadTma { work_desc.seq_idx, work_desc.q_head_idx(), work_desc.tok_offset); + // Q lives in the QK head space. Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( problem_size.total_seqlen, problem_size.head_size, - problem_size.num_heads)); // global view to the packed varlen sequence + problem_size.num_qk_heads)); // global view to the packed varlen sequence Tensor m_varlen = m_varlen_head(_, _, work_desc.q_head_idx()); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(work_desc.tok_offset, _0{}), @@ -103,18 +104,19 @@ struct CollectiveLoadTma { Tensor g_full = local_tile(m_offset, make_tile(BlkSeqQ, HeadSize), make_coord(_, _0{})); // (blk, d, iter_blk) return g_full; - } else if constexpr (kind == LoadKind::kAlpha) { // same as Q currently + } else if constexpr (kind == LoadKind::kAlpha) { + // Alpha (gate) is per V/O head under GVA. 
DPRINTF0_W( "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", to_string(kind), work_desc.seq_idx, - work_desc.q_head_idx(), + work_desc.o_head_idx(), work_desc.tok_offset); Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( problem_size.total_seqlen, problem_size.head_size, - problem_size.num_heads)); // global view to the packed varlen sequence - Tensor m_varlen = m_varlen_head(_, _, work_desc.q_head_idx()); // slice into current head_idx + problem_size.num_v_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, work_desc.o_head_idx()); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(work_desc.tok_offset, _0{}), m_varlen); // offset to start of the current sequence @@ -122,7 +124,11 @@ struct CollectiveLoadTma { local_tile(m_offset, make_tile(BlkSeqQ, HeadSize), make_coord(_, _0{})); // (blk, d, iter_blk) return g_full; } else { - auto head_idx = (kind == LoadKind::kK ? work_desc.k_head_idx() : work_desc.v_head_idx()); + // K lives in the QK head space; V lives in the V head space. + // `kind` is a static constexpr LoadKind, so the head-count selection collapses at compile time. + constexpr bool kIsK = (kind == LoadKind::kK); + auto head_idx = kIsK ? work_desc.k_head_idx() : work_desc.v_head_idx(); + auto num_kv_heads = kIsK ? problem_size.num_qk_heads : problem_size.num_v_heads; DPRINTF0_W( "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", to_string(kind), @@ -132,7 +138,7 @@ struct CollectiveLoadTma { Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( problem_size.head_size, problem_size.total_seqlen, - problem_size.num_heads)); // global view to the packed varlen sequence + num_kv_heads)); // global view to the packed varlen sequence Tensor m_varlen = m_varlen_head(_, _, head_idx); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(_0{}, work_desc.tok_offset), diff --git a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp index d22814d..a7bacf7 100644 --- a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp +++ b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp @@ -471,7 +471,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { Element const* ptr_V; LayoutV dV; Element* ptr_O; LayoutO dO; float const* ptr_Alpha; LayoutAlpha dAlpha; - float* ptr_output_state; // layout fixed (kdim, vdim, num_heads, num_seqs):LayoutLeft{} + float* ptr_output_state; // layout fixed (kdim, vdim, num_v_heads, num_seqs):LayoutLeft{} float const* ptr_input_state; float scale; ElementBetaGmem const* beta_ptr; GmemStrideBeta beta_stride; @@ -506,15 +506,17 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { int64_t t = problem_size.total_seqlen; int32_t d = problem_size.head_size; + // GVA: Q/K are sized by num_qk_heads; V, alpha (gate), O, beta and the recurrent state + // are sized by num_v_heads. 
auto params_qk = CollectiveMmaQK::to_underlying_arguments( - make_shape(s, t, d, problem_size.num_heads), + make_shape(s, t, d, problem_size.num_qk_heads), typename CollectiveMmaQK::Arguments{ args.ptr_Q, args.dQ, args.ptr_K, args.dK, // never used, dummy }, /*workspace=*/nullptr); auto params_kv_k = CollectiveMmaKV_G2S::to_underlying_arguments( - make_shape(d, d, s, problem_size.num_heads), + make_shape(d, d, s, problem_size.num_qk_heads), typename CollectiveMmaKV_G2S::Arguments{ args.ptr_V, select<1, 0, 2>(args.dV), // not used @@ -523,7 +525,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { }, /*workspace=*/nullptr); - auto alpha_shape = make_shape(s, d, problem_size.num_heads); + auto alpha_shape = make_shape(s, d, problem_size.num_v_heads); auto alpha_stride = make_stride( get<0>(args.dAlpha), // seqlen stride get<1>(args.dAlpha), // head_dim stride @@ -538,7 +540,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { size<0>(ClusterShape{})); auto params_kv_v = CollectiveMmaKV_G2S::to_underlying_arguments( - make_shape(d, d, s, problem_size.num_heads), + make_shape(d, d, s, problem_size.num_v_heads), typename CollectiveMmaKV_G2S::Arguments{ args.ptr_V, select<1, 0, 2>(args.dV), // used as G2S for V @@ -548,8 +550,8 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { /*workspace=*/nullptr); auto params_o = CollectiveStoreO::to_underlying_arguments( - make_shape(d, s, d, problem_size.num_heads), // in O1 - // make_shape(d, s, s, problem_size.num_heads), // in O2 + make_shape(d, s, d, problem_size.num_v_heads), // in O1 + // make_shape(d, s, s, problem_size.num_v_heads), // in O2 typename CollectiveStoreO::Arguments{args.ptr_O, select<1, 0, 2>(args.dO), workspace}, workspace); @@ -567,7 +569,7 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { // TODO: refactor all name to varname_vartype .beta_ptr = args.beta_ptr, - .beta_layout = make_layout(make_shape(s, problem_size.num_heads), args.beta_stride), + .beta_layout = make_layout(make_shape(s, problem_size.num_v_heads), args.beta_stride), }; } @@ -899,7 +901,8 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { auto kv_load = [&](auto& tKVrKV) INLINE_LAMBDA { DPRINTF0_WG("[%d,%d,%d,%d]>> load tKVgKV -> tKVrKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); - int num_state_heads = problem_size.num_heads; + // GVA: state is stored per V/O head. + int num_state_heads = problem_size.num_v_heads; int state_head_idx = work_desc.o_head_idx(); auto gKV = make_tensor( make_gmem_ptr(params.ptr_input_state), @@ -915,7 +918,8 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { auto kv_store = [&]() INLINE_LAMBDA { // tKVrKV is carried over whole mainloop DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); - int num_state_heads = problem_size.num_heads; + // GVA: state is stored per V/O head. 
+    int num_state_heads = problem_size.num_v_heads;
     int state_head_idx = work_desc.o_head_idx();
     auto gKV = make_tensor(
         make_gmem_ptr(params.ptr_output_state),
diff --git a/csrc/kda/sm90/collective/store_tma.hpp b/csrc/kda/sm90/collective/store_tma.hpp
index f7b986a..6e2a784 100644
--- a/csrc/kda/sm90/collective/store_tma.hpp
+++ b/csrc/kda/sm90/collective/store_tma.hpp
@@ -195,7 +195,7 @@ struct CollectiveStoreTma {
     Tensor m_varlen_head = tma_store_.get_tma_tensor(make_shape(
         problem_size.head_size,
         problem_size.total_seqlen,
-        problem_size.num_heads)); // global view to the packed varlen sequence
+        problem_size.num_v_heads)); // O lives in the V/O head space under GVA
     Tensor m_varlen = m_varlen_head(_, _, work_desc.o_head_idx()); // slice into current head_idx
     Tensor m_offset = domain_offset(
         make_coord(_0{}, work_desc.tok_offset),
diff --git a/csrc/kda/sm90/kda_fwd_sm90.cu b/csrc/kda/sm90/kda_fwd_sm90.cu
index c13b1bd..d668db9 100644
--- a/csrc/kda/sm90/kda_fwd_sm90.cu
+++ b/csrc/kda/sm90/kda_fwd_sm90.cu
@@ -48,7 +48,8 @@ launch_kda_fwd_prefill_kernel_gbai(
     int32_t const* cu_seqlens,
     uint8_t* workspace_buffer,
     int32_t num_seqs,
-    int32_t num_heads,
+    int32_t num_qk_heads,
+    int32_t num_v_heads,
     int32_t head_size,
     int64_t total_seqlen,
     float scale,
@@ -74,7 +75,8 @@ launch_kda_fwd_prefill_kernel(
     int32_t const* cu_seqlens,
     uint8_t* workspace_buffer,
     int32_t num_seqs,
-    int32_t num_heads,
+    int32_t num_qk_heads,
+    int32_t num_v_heads,
     int32_t head_size,
     int64_t total_seqlen,
     float scale,
@@ -98,7 +100,8 @@ launch_kda_fwd_prefill_kernel(
       cu_seqlens,                                                                                                    \
       workspace_buffer,                                                                                              \
       num_seqs,                                                                                                      \
-      num_heads,                                                                                                     \
+      num_qk_heads,                                                                                                  \
+      num_v_heads,                                                                                                   \
       head_size,                                                                                                     \
       total_seqlen,                                                                                                  \
       scale,                                                                                                         \
@@ -137,7 +140,8 @@ launch_kda_fwd_prefill_kernel(
     int32_t const* cu_seqlens,
     uint8_t* workspace_buffer,
     int32_t num_seqs,
-    int32_t num_heads,
+    int32_t num_qk_heads,
+    int32_t num_v_heads,
     int32_t head_size,
     int64_t total_seqlen,
     float scale,
@@ -159,7 +163,8 @@ launch_kda_fwd_prefill_kernel(
     int32_t const* cu_seqlens,
     uint8_t* workspace_buffer,
     int32_t num_seqs,
-    int32_t num_heads,
+    int32_t num_qk_heads,
+    int32_t num_v_heads,
     int32_t head_size,
     int64_t total_seqlen,
     float scale,
diff --git a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu
index 93fe869..309cefa 100644
--- a/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu
+++ b/csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu
@@ -40,6 +40,7 @@ launch_kda_fwd_prefill_kernel_gbai
 #include
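The recurrent state moved by kv_load/kv_store above is kept per V/O head: one [head_size, head_size] fp32 matrix per (seq, v_head), packed as [num_seqs, num_v_heads, D, D]. The layout the kernel reads and writes is the vk-transposed view of the naive [k, v] reference state, which is why the tests later compare tri_ht.transpose(-1, -2) against the reference. A minimal PyTorch sketch of just that layout relationship (illustrative shapes only, not kernel code):

import torch

N, HV, D = 2, 4, 128  # assumed example sizes
h0 = torch.randn(N, HV, D, D, dtype=torch.float32)  # reference state S[n, hv, k, v]
h0_vk = h0.transpose(-1, -2).contiguous()           # decode-oriented layout S[n, hv, v, k]
assert torch.equal(h0_vk.transpose(-1, -2), h0)     # transposing back recovers the reference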
diff --git a/csrc/kda/sm90/kernel/tile_scheduler.hpp b/csrc/kda/sm90/kernel/tile_scheduler.hpp
--- a/csrc/kda/sm90/kernel/tile_scheduler.hpp
+++ b/csrc/kda/sm90/kernel/tile_scheduler.hpp
+// Tile scheduler: hands out work to every thread block on the GPU, telling each block which sample and which attention head to process, where to start, and how long
+// Under GVA, each program handles one V-head; the V-heads within the same GVA group share the same Q/K
 namespace kda::sm90::kernel {
 using namespace cute;
+// The work descriptor the scheduler sends to each block; once a block has it, it knows where to fetch data and what to process
+// Under GVA: head_idx is the head index of V/O/g/beta (range [0, num_v_heads)),
+// qk_head_idx is the head index of Q/K (range [0, num_qk_heads)).
 struct WorkDesc {
   // coord
-  int32_t seq_idx;
-  int32_t head_idx;
-  int64_t tok_offset; // offset to the start of the start
+  int32_t seq_idx;     // which sequence am I processing
+  int32_t qk_head_idx; // head idx used by Q/K (representative of the GVA group)
+  int32_t head_idx;    // head idx used by V/O/g/beta
+  int64_t tok_offset;  // where this sequence starts in the big packed array

   // shape
-  int64_t seq_len;
+  int64_t seq_len; // how long this sequence is

   // update by mainloop
-  int32_t tile_idx = 0;
+  int32_t tile_idx = 0; // which tile we are on (updated in the mainloop)

   template
   CUTE_DEVICE bool
@@ -42,11 +48,11 @@ struct WorkDesc {

   CUTE_DEVICE int32_t
   q_head_idx() const {
-    return head_idx;
+    return qk_head_idx;
   }
   CUTE_DEVICE int32_t
   k_head_idx() const {
-    return head_idx;
+    return qk_head_idx;
   }
   CUTE_DEVICE int32_t
   v_head_idx() const {
@@ -64,11 +70,14 @@ struct WorkDesc {
   }
 };

+// Each block independently handles one (seq, v_head) task; nothing is shared between blocks.
+// GVA optimization: heads_per_group is precomputed on the host and attached to Params, so the device side never repeats the integer division.
 struct IndividualTileScheduler {
   struct Params {
     dim3 grid;
     int32_t num_seqs;
-    int32_t num_heads;
+    int32_t num_v_heads;
+    int32_t heads_per_group; // = num_v_heads / num_qk_heads, precomputed on the host
   };

   bool scheduled = false; // a once flag
@@ -84,19 +93,25 @@ struct IndividualTileScheduler {
       cutlass::KernelHardwareInfo const& hw_info,
       ClusterShape const& cluster_shape,
       TileShape const& tile_shape) {
+    // Compute heads_per_group once on the host to avoid an integer division in every CTA
+    int32_t const heads_per_group = problem_size.num_v_heads / problem_size.num_qk_heads;
     dim3 grid(0, 1, 1);
-    grid.x = problem_size.num_seqs * problem_size.num_heads;
+    grid.x = problem_size.num_seqs * problem_size.num_v_heads;
     DPRINTF(
-        "to_underlying_arguments: grid:{.x:%d, .y:%d, .z:%d}, num_seqs:%d, num_heads:%d\n",
+        "to_underlying_arguments: grid:{.x:%d, .y:%d, .z:%d}, num_seqs:%d, num_qk_heads:%d, num_v_heads:%d, "
+        "heads_per_group:%d\n",
         grid.x,
         grid.y,
         grid.z,
         problem_size.num_seqs,
-        problem_size.num_heads);
+        problem_size.num_qk_heads,
+        problem_size.num_v_heads,
+        heads_per_group);
     return {
         .grid = grid,
         .num_seqs = problem_size.num_seqs,
-        .num_heads = problem_size.num_heads,
+        .num_v_heads = problem_size.num_v_heads,
+        .heads_per_group = heads_per_group,
     };
   }

@@ -108,8 +123,10 @@ struct IndividualTileScheduler {
   template
   CUTE_DEVICE WorkDesc
   get_next_work(Params params, ProblemSize const& problem_size) {
-    int32_t seq_idx = blockIdx.x / params.num_heads;
-    int32_t head_idx = blockIdx.x % params.num_heads;
+    int32_t seq_idx = blockIdx.x / params.num_v_heads;
+    int32_t head_idx = blockIdx.x % params.num_v_heads;
+    // GVA: use heads_per_group precomputed on the host to avoid a device-side integer division
+    int32_t qk_head_idx = head_idx / params.heads_per_group;

     int32_t s = problem_size.cu_seqlens[seq_idx];
     int32_t e = problem_size.cu_seqlens[seq_idx + 1];
@@ -120,8 +137,9 @@ struct IndividualTileScheduler {
     } else {
       scheduled = true;
       DPRINTF0_W(
-          "get_next_work: this_work={seq_idx:%d head_idx:%d tok_offset:%lld seq_len:%lld}\n",
+          "get_next_work: this_work={seq_idx:%d qk_head_idx:%d head_idx:%d tok_offset:%lld seq_len:%lld}\n",
           seq_idx,
+          qk_head_idx,
           head_idx,
           s,
           seq_len);
@@ -129,6 +147,7 @@

     return {
         .seq_idx = seq_idx,
+        .qk_head_idx = qk_head_idx,
         .head_idx = head_idx,
         .tok_offset = s,
         .seq_len = seq_len,
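Condensed into plain Python, the block-to-work mapping that IndividualTileScheduler implements above looks like the following sketch (host-side illustration only; the field names mirror WorkDesc):

def work_for_block(block_idx: int, num_v_heads: int, heads_per_group: int, cu_seqlens: list[int]):
    # Mirrors IndividualTileScheduler::get_next_work with grid.x = num_seqs * num_v_heads.
    seq_idx = block_idx // num_v_heads
    head_idx = block_idx % num_v_heads            # V/O/g/beta head owned by this block
    qk_head_idx = head_idx // heads_per_group     # Q/K head shared by the whole GVA group
    tok_offset = cu_seqlens[seq_idx]
    seq_len = cu_seqlens[seq_idx + 1] - tok_offset  # seq_len <= 0 means "no work"
    return seq_idx, qk_head_idx, head_idx, tok_offset, seq_len

# Example: num_qk_heads=2, num_v_heads=4 (heads_per_group=2), cu_seqlens=[0, 63, 130].
# Block 5 processes sequence 1, v-head 1, which reads Q/K head 0:
assert work_for_block(5, 4, 2, [0, 63, 130]) == (1, 0, 1, 63, 67)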
diff --git a/csrc/kda/sm90/prefill_kernel.hpp b/csrc/kda/sm90/prefill_kernel.hpp
index 00cef2d..d56fafa 100644
--- a/csrc/kda/sm90/prefill_kernel.hpp
+++ b/csrc/kda/sm90/prefill_kernel.hpp
@@ -40,7 +40,8 @@ launch_kda_fwd_prefill_kernel(
     int32_t const* cu_seqlens,
     uint8_t* workspace_buffer,
     int32_t num_seqs,
-    int32_t num_heads,
+    int32_t num_qk_heads,
+    int32_t num_v_heads,
     int32_t head_size,
     int64_t total_seqlen,
     float scale,
diff --git a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh
index 2fb9bda..72f13a6 100644
--- a/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh
+++ b/csrc/kda/sm90/prefill_kernel_kda_fwd_sm90.cuh
@@ -53,7 +53,8 @@ launch_kda_fwd_prefill_kernel_gbai(
     int32_t const* cu_seqlens,
     uint8_t* workspace_buffer,
     int32_t num_seqs,
-    int32_t num_heads,
+    int32_t num_qk_heads,
+    int32_t num_v_heads,
     int32_t head_size,
     int64_t total_seqlen,
     float scale,
@@ -105,8 +106,11 @@ launch_kda_fwd_prefill_kernel_gbai(
   using Arguments = typename Operation::Arguments;

   // NOTE: LayoutQ/K/V in (seq, head_size, (b,h)) coordinate semantics
+  // GVA: Q/K rows are packed as [packed_seq, num_qk_heads, head_size];
+  // V/O/g/beta rows are packed as [packed_seq, num_v_heads, head_size].

-  int32_t tok_stride = num_heads * head_size;
+  int32_t qk_tok_stride = num_qk_heads * head_size;
+  int32_t v_tok_stride = num_v_heads * head_size;
   int32_t head_stride = head_size;

   Operation op;
@@ -116,21 +120,22 @@ launch_kda_fwd_prefill_kernel_gbai(
           .cu_seqlens = cu_seqlens,
           .total_seqlen = total_seqlen,
           .num_seqs = num_seqs,
-          .num_heads = num_heads,
+          .num_qk_heads = num_qk_heads,
+          .num_v_heads = num_v_heads,
           .head_size = head_size,
       },
       .mainloop =
          {
              // clang-format off
-            .ptr_Q = (T*)q,           .dQ = {tok_stride, _1{}, head_stride},
-            .ptr_K = (T*)k,           .dK = {tok_stride, _1{}, head_stride},
-            .ptr_V = (T*)v,           .dV = {tok_stride, _1{}, head_stride},
-            .ptr_O = (T*)output,      .dO = {tok_stride, _1{}, head_stride},
-            .ptr_Alpha = alpha,       .dAlpha = {tok_stride, _1{}, head_stride},
+            .ptr_Q = (T*)q,           .dQ = {qk_tok_stride, _1{}, head_stride},
+            .ptr_K = (T*)k,           .dK = {qk_tok_stride, _1{}, head_stride},
+            .ptr_V = (T*)v,           .dV = {v_tok_stride, _1{}, head_stride},
+            .ptr_O = (T*)output,      .dO = {v_tok_stride, _1{}, head_stride},
+            .ptr_Alpha = alpha,       .dAlpha = {v_tok_stride, _1{}, head_stride},
             .ptr_output_state = (float*)output_state,
             .ptr_input_state = (float*)input_state,
             .scale = scale,
-            .beta_ptr = beta,         .beta_stride = {num_heads, 1},
+            .beta_ptr = beta,         .beta_stride = {num_v_heads, 1},
          },
      // clang-format on
      .hw_info = hw_info};
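The two token strides above are exactly what torch reports for contiguous packed tensors, so the layout is easy to sanity-check from Python (a sketch with assumed example sizes, not launcher code):

import torch

T, H, HV, D = 10, 2, 8, 128
q = torch.empty(T, H, D)     # packed Q/K rows
v = torch.empty(T, HV, D)    # packed V/O/g rows
assert q.stride() == (H * D, D, 1)   # (qk_tok_stride, head_stride, 1)
assert v.stride() == (HV * D, D, 1)  # (v_tok_stride, head_stride, 1)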
diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py
index 9a5af2b..84c06d4 100644
--- a/cula/kda/hopper_fused_fwd.py
+++ b/cula/kda/hopper_fused_fwd.py
@@ -24,7 +24,7 @@
 import cula.cudac as cula_cuda
 from cula.utils import _get_cache_buf, assert_hopper, get_device_sm_count, prepare_uniform_cu_seqlens

-
+# PyTorch autograd can only trace pure PyTorch ops. Once we call a custom CUDA kernel (cula_cuda.kda_fwd_prefill), the autograd graph is cut, so to train this op together with the upstream network we must subclass torch.autograd.Function and declare forward + backward by hand.
 class HopperChunkKDAFunction(torch.autograd.Function):
     @staticmethod
     @input_guard
@@ -49,9 +49,21 @@ def forward(
         chunk_indices: torch.IntTensor | None = None,
     ):
         chunk_size = 64
-        assert q.shape[-2] == v.shape[-2] == k.shape[-2], "Number of heads must be the same for q, k, v."
+        # GVA: q/k share num_qk_heads; v/g/beta share num_v_heads.
+        # num_v_heads must be a positive multiple of num_qk_heads (heads_per_group = HV / H).
+        assert q.shape == k.shape, "q and k must have the same shape."
+        assert q.shape[:2] == v.shape[:2] == g.shape[:2], (
+            "q, k, v, g must share batch and sequence dimensions."
+        )

-        batch_size, seq_len, num_heads, head_dim = q.shape
+        batch_size, seq_len, num_qk_heads, head_dim = q.shape
+        num_v_heads = v.shape[-2]
+        # Order matters: enforce positivity *before* the modulo so we never % 0.
+        assert num_qk_heads > 0, f"num_qk_heads must be positive, got {num_qk_heads}."
+        assert num_v_heads > 0, f"num_v_heads must be positive, got {num_v_heads}."
+        assert num_v_heads % num_qk_heads == 0, (
+            f"num_v_heads ({num_v_heads}) must be a positive multiple of num_qk_heads ({num_qk_heads})."
+        )

         if cu_seqlens is None:
             cu_seqlens = prepare_uniform_cu_seqlens(batch_size, seq_len, q.device, torch.int32)
@@ -61,7 +73,7 @@ def forward(
             q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta))

         # gate preprocessing
-        if use_gate_in_kernel:
+        if use_gate_in_kernel:  ## let the KDA kernel take over the full gate preprocessing (activation + clamp + cumsum), or only do the final cumsum step
             if safe_gate:
                 assert lower_bound is not None, "lower_bound must be set when use safe_gate"
                 g = kda_gate_chunk_cumsum(
@@ -88,13 +100,13 @@ def forward(
             q, q_rstd = l2norm_fwd(q)
             k, k_rstd = l2norm_fwd(k)

-        # reshape to packed [T, H, K] for the C++ kernel
+        # reshape to packed [T, H, D] / [T, HV, D] for the C++ kernel
         packed_seq = batch_size * seq_len
-        q = q.reshape(packed_seq, num_heads, head_dim).contiguous()
-        k = k.reshape(packed_seq, num_heads, head_dim).contiguous()
-        v = v.reshape(packed_seq, num_heads, head_dim).contiguous()
-        g = g.reshape(packed_seq, num_heads, head_dim).contiguous()
-        beta = beta.reshape(packed_seq, num_heads).contiguous()
+        q = q.reshape(packed_seq, num_qk_heads, head_dim).contiguous()
+        k = k.reshape(packed_seq, num_qk_heads, head_dim).contiguous()
+        v = v.reshape(packed_seq, num_v_heads, head_dim).contiguous()
+        g = g.reshape(packed_seq, num_v_heads, head_dim).contiguous()
+        beta = beta.reshape(packed_seq, num_v_heads).contiguous()

         # workspace buffer for TMA Store O tensormap
         sm_count = get_device_sm_count(q.device)
@@ -121,7 +133,10 @@ def forward(

         # reshape back
         o = rearrange(o, "(b t) h d -> b t h d", b=batch_size)

-        return o.to(q.dtype), final_state
+        # Bug fix: respect output_final_state=False explicitly.
+        # The C++ kernel always allocates an output_state tensor, but the public API
+        # promises None when the caller did not opt-in.
+        return o.to(q.dtype), (final_state if output_final_state else None)

     @staticmethod
     @input_guard
@@ -153,23 +168,23 @@ def cula_kda_prefill(

    Args:
        q (torch.Tensor):
-            queries of shape `[B, T, H, K]`.
+            queries of shape `[B, T, H, D]`.
        k (torch.Tensor):
-            keys of shape `[B, T, H, K]`.
+            keys of shape `[B, T, H, D]`.
        v (torch.Tensor):
-            values of shape `[B, T, H, V]`.
+            values of shape `[B, T, HV, D]`.
        g (torch.Tensor):
-            (forget) gating tensor (in log space!) of shape `[B, T, H, K]`.
+            (forget) gating tensor (in log space!) of shape `[B, T, HV, D]`.
        beta (torch.Tensor):
-            betas of shape `[B, T, H]`.
+            betas of shape `[B, T, HV]`.
        scale (Optional[float]):
            Scale factor for the KDA attention scores.
-            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+            If not provided, it will default to `1 / sqrt(D)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
-            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            Initial state of shape `[N, HV, D, D]` for `N` input sequences.
            Default: `None`.
        output_final_state (Optional[bool]):
-            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+            Whether to output the final state of shape `[N, HV, D, D]`. Default: `False`.
        use_qk_l2norm_in_kernel (bool):
            Whether to apply L2norm to the q,k tensor internally. Default: `False`.
        use_gate_in_kernel (bool):
@@ -187,9 +202,15 @@ def cula_kda_prefill(

    Returns:
        o (torch.Tensor):
-            Outputs of shape `[B, T, H, V]`.
+            Outputs of shape `[B, T, HV, D]`.
        final_state (torch.Tensor):
-            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+            Final state of shape `[N, HV, D, D]` if `output_final_state=True` else `None`.
+
+    GVA constraint:
+    - q.shape == k.shape == [B, T, H, D]
+    - v.shape == g.shape == [B, T, HV, D], beta.shape == [B, T, HV]
+    - HV must be a positive multiple of H. heads_per_group = HV // H.
+    - When HV == H this degenerates to the regular MHA case.
    """
    assert_hopper()
    assert safe_gate, "Only support safe_gate=True."
@@ -217,9 +238,34 @@ def cula_kda_prefill(
     if not (-5 <= lower_bound < 0):
         raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.")

-    assert q.shape == k.shape == g.shape, "q, k, g must have the same shape."
-    assert beta.shape == q.shape[:3], "beta must be of shape (batch size, seq len, num of head)."
-    assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)."
+    assert q.shape == k.shape, "q and k must have the same shape."
+    assert q.shape[:2] == v.shape[:2] == g.shape[:2], (
+        "q, k, v, g must share batch and sequence dimensions."
+    )
+
+    batch_size, seq_len, num_qk_heads, head_dim = q.shape
+    num_v_heads = v.shape[-2]
+    # Order matters here: positivity *first*, modulo second, to avoid ZeroDivisionError on bad inputs.
+    assert num_qk_heads > 0, f"num_qk_heads must be positive, got {num_qk_heads}."
+    assert num_v_heads > 0, f"num_v_heads must be positive, got {num_v_heads}."
+    assert num_v_heads % num_qk_heads == 0, (
+        f"num_v_heads ({num_v_heads}) must be a positive multiple of num_qk_heads ({num_qk_heads})."
+    )
+    assert g.shape == (batch_size, seq_len, num_v_heads, head_dim), (
+        f"g must have shape (B, T, HV, D)=({batch_size}, {seq_len}, {num_v_heads}, {head_dim}), got {tuple(g.shape)}."
+    )
+    assert v.shape == (batch_size, seq_len, num_v_heads, head_dim), (
+        f"v must have shape (B, T, HV, D)=({batch_size}, {seq_len}, {num_v_heads}, {head_dim}), got {tuple(v.shape)}."
+    )
+    assert beta.shape == (batch_size, seq_len, num_v_heads), (
+        f"beta must have shape (B, T, HV)=({batch_size}, {seq_len}, {num_v_heads}), got {tuple(beta.shape)}."
+    )
+    if initial_state is not None:
+        expected_num_states = (len(cu_seqlens) - 1) if cu_seqlens is not None else batch_size
+        assert initial_state.shape == (expected_num_states, num_v_heads, head_dim, head_dim), (
+            f"initial_state must have shape (N, HV, D, D)="
+            f"({expected_num_states}, {num_v_heads}, {head_dim}, {head_dim}), got {tuple(initial_state.shape)}."
+        )
     assert q.dtype == k.dtype == v.dtype == torch.bfloat16, "q, k, v must be in bfloat16."
     assert beta.dtype == torch.bfloat16 or beta.dtype == torch.float32, "beta must be in bfloat16 or float32."
     assert q.shape[-1] == k.shape[-1] == v.shape[-1] == 128, "Currently we only support head dim of 128 for KDA"
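When use_gate_in_kernel=False, the caller must hand in a log-space gate that already satisfies the safe_gate contract checked above. A hedged reference of that preprocessing, mirroring what the tests construct (the fused kda_gate_chunk_cumsum path additionally folds the per-chunk cumsum into the kernel; safe_gate_reference is an illustrative name, not a library function):

import torch
import torch.nn.functional as F

def safe_gate_reference(x: torch.Tensor, gate_logit_normalizer: float = 1.0,
                        lower_bound: float = -5.0) -> torch.Tensor:
    # Log-space forget gate g <= 0; clamping keeps exp(g) in [exp(lower_bound), 1],
    # matching the wrapper's requirement that -5 <= lower_bound < 0.
    g = F.logsigmoid(x.float()) / gate_logit_normalizer
    return g.clamp(min=lower_bound, max=0.0)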
diff --git a/cula/utils.py b/cula/utils.py
index bd70730..980e21a 100644
--- a/cula/utils.py
+++ b/cula/utils.py
@@ -80,6 +80,7 @@ def assert_hopper(device: torch.device | str | int | None = None) -> None:
         raise RuntimeError(f"Only Hopper GPUs (SM90) are supported, got compute capability sm_{major}{minor}.")


+## Factory that dispatches the forward kernel by GPU architecture: based on the SM version of the CUDA GPU in use, return a kda_prefill implementation matching that hardware
 def get_kda_fused_fwd(device: torch.device | str | int | None = None) -> Callable:
     """Return the appropriate ``kda_prefill`` implementation for *device*.
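The tests below reduce GVA to an MHA-shaped reference problem: each of the H query/key heads is repeated heads_per_group times so that the naive per-head recurrence runs with HV independent heads. The core of that pattern (sketch of the test setup, assumed example sizes):

import torch

B, T, H, HV, D = 2, 16, 2, 4, 128
heads_per_group = HV // H
q = torch.randn(B, T, H, D)
q_ref = q.repeat_interleave(heads_per_group, dim=2)  # [B, T, H, D] -> [B, T, HV, D]
# V-head j reads Q/K head j // heads_per_group, matching the scheduler's qk_head_idx.
assert torch.equal(q_ref[:, :, 3], q[:, :, 3 // heads_per_group])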
diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index b9b59f9..e1f0b0f 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -36,28 +36,40 @@ "B", "T", "H", + "HV", "D", "gate_logit_normalizer", "mask_p", "use_qk_l2norm_in_kernel", "use_gate_in_kernel", "safe_gate", + "use_initial_state", + "output_final_state", + "deterministic", "dtype", ), [ pytest.param( *test, - id="B{}-T{}-H{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), + id=("B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-init{}-outstate{}-deterministic{}-{}").format( + *test + ), ) for test in [ - (1, 63, 1, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 500, 3, 128, 1, 0, False, False, True, torch.bfloat16), - (2, 1000, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), - (3, 1024, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, False, False, True, torch.bfloat16), - (4, 1024, 4, 128, 1, 0, True, False, True, torch.bfloat16), - (2, 1500, 4, 128, 10, 0, False, True, True, torch.bfloat16), - (4, 2048, 8, 128, 1, 0, False, True, True, torch.bfloat16), + (1, 63, 1, 1, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (2, 500, 3, 3, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, True, True, False, torch.bfloat16), + (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, True, True, False, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, True, False, True, True, True, False, torch.bfloat16), + (2, 1500, 4, 4, 128, 10, 0, False, True, True, True, True, False, torch.bfloat16), + (4, 2048, 8, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), + (2, 512, 2, 4, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (2, 1024, 2, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), + (1, 64, 1, 2, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (1, 65, 1, 4, 128, 1, 0, False, False, True, False, True, False, torch.bfloat16), + (2, 65, 2, 4, 128, 1, 0, False, False, True, False, False, False, torch.bfloat16), + (1, 65, 2, 4, 128, 1, 0, False, False, True, True, True, True, torch.bfloat16), ] ], ) @@ -65,12 +77,16 @@ def test_safe_gate_chunk( B: int, T: int, H: int, + HV: int, D: int, gate_logit_normalizer: float, mask_p: float, use_qk_l2norm_in_kernel: bool, use_gate_in_kernel: bool, safe_gate: bool, + use_initial_state: bool, + output_final_state: bool, + deterministic: bool, dtype: torch.dtype, beta_dtype: torch.dtype, ): @@ -81,11 +97,23 @@ def test_safe_gate_chunk( torch.manual_seed(42) q = torch.rand(B, T, H, D, dtype=dtype) k = torch.rand(B, T, H, D, dtype=dtype) - v = torch.rand(B, T, H, D, dtype=dtype) - g = torch.randn(B, T, H, D, dtype=torch.float if not use_gate_in_kernel else dtype) + v = torch.rand(B, T, HV, D, dtype=dtype) + g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype) + if deterministic: + assert H == 2 and HV == 4 and not use_gate_in_kernel + q = torch.zeros(B, T, H, D, dtype=dtype) + k = torch.zeros(B, T, H, D, dtype=dtype) + v = torch.zeros(B, T, HV, D, dtype=dtype) + g = torch.zeros(B, T, HV, D, dtype=torch.float) + q[:, :, 0, 0] = 1 + q[:, :, 1, 1] = 1 + k[:, :, 0, 0] = 1 + k[:, :, 1, 0] = 1 + for i in range(HV): + v[:, :, i] = i + 1 if use_gate_in_kernel: - A_log = torch.randn(H, dtype=torch.float) - dt_bias = torch.randn(H * D, 
dtype=torch.float) + A_log = torch.randn(HV, dtype=torch.float) + dt_bias = torch.randn(HV * D, dtype=torch.float) else: g = F.logsigmoid(g) / gate_logit_normalizer g = g * (torch.rand_like(g) > mask_p) @@ -98,34 +126,43 @@ def test_safe_gate_chunk( lower_bound = None naive_kda_gate_fn = naive_kda_gate - beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn(B, H, D, D, dtype=torch.float32) + beta = torch.randn(B, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn(B, HV, D, D, dtype=torch.float32) + if deterministic: + beta = torch.ones(B, T, HV, dtype=beta_dtype) + h0 = torch.zeros(B, HV, D, D, dtype=torch.float32) # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance h0_vk = h0.transpose(-1, -2).contiguous() if use_gate_in_kernel: A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(False), (A_log, dt_bias)) q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + initial_state = h0.clone() if use_initial_state else None + initial_state_vk = h0_vk.clone() if use_initial_state else None + + heads_per_group = HV // H + q_ref = q.repeat_interleave(heads_per_group, dim=2) + k_ref = k.repeat_interleave(heads_per_group, dim=2) ref, ref_ht = naive_recurrent_kda( - q=F.normalize(q.clone(), p=2, dim=-1), - k=F.normalize(k.clone(), p=2, dim=-1), + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=F.normalize(k_ref.clone(), p=2, dim=-1), v=v.clone(), g=(naive_kda_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), beta=beta.clone(), - initial_state=h0.clone(), - output_final_state=True, + initial_state=initial_state, + output_final_state=output_final_state, ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), - k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), + k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), A_log=(A_log.clone() if use_gate_in_kernel else None), dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=h0.clone(), - output_final_state=True, + initial_state=initial_state, + output_final_state=output_final_state, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_gate_in_kernel=use_gate_in_kernel, safe_gate=safe_gate, @@ -133,15 +170,15 @@ def test_safe_gate_chunk( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), - k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), + k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), A_log=(A_log.clone() if use_gate_in_kernel else None), dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=h0_vk.clone(), - output_final_state=True, + initial_state=initial_state_vk, + output_final_state=output_final_state, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_gate_in_kernel=use_gate_in_kernel, safe_gate=safe_gate, @@ -157,8 +194,8 @@ def test_safe_gate_chunk( beta=beta.clone(), A_log=(A_log.clone() if use_gate_in_kernel else None), 
dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=h0_vk.clone(), - output_final_state=True, + initial_state=initial_state_vk, + output_final_state=output_final_state, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_gate_in_kernel=use_gate_in_kernel, safe_gate=safe_gate, @@ -166,67 +203,96 @@ def test_safe_gate_chunk( ) assert_close("o", ref, tri, 0.005) - assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) assert_close("o", ref_fla, tri, 0.005) - assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) assert_close("o", ref_fla_trans, tri, 0.005) - assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005) + if output_final_state: + assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) + assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) + assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005) + else: + assert ref_ht is None + assert ref_ht_fla is None + assert ref_ht_fla_trans is None + assert tri_ht is None @pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize( - ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"), + ("H", "HV", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate", "use_initial_state", "output_final_state"), [ - pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}".format(*test)) + pytest.param( + *test, + id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}-init{}-outstate{}".format(*test), + ) for test in [ - (4, 128, 0.1, [0, 15], torch.bfloat16, True), - (4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True), - (4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True), - (4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True), + (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True, True, True), + (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True, True, True), + (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True, True, True), + (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True, True, True), + (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True, True, True), + (2, 4, 128, 0, [0, 63, 130], torch.bfloat16, True, True, True), + (1, 2, 128, 0, [0, 1], torch.bfloat16, True, True, True), + (1, 2, 128, 0, [0, 63, 64, 65], torch.bfloat16, True, True, True), + (2, 4, 128, 0, [0, 17, 64, 65, 130], torch.bfloat16, True, False, True), + (4, 8, 128, 0.5, [0, 15, 100, 300], torch.bfloat16, True, True, False), # ======Varlen test with simulated trace======= ( + 32, 32, 128, 0, [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], torch.bfloat16, True, + True, + True, ), ( + 32, 32, 128, 0, [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], torch.bfloat16, True, + True, + True, ), ( + 32, 32, 128, 0, [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], torch.bfloat16, True, + True, + True, ), ( + 32, 32, 128, 0, [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], torch.bfloat16, True, + True, + True, ), ] ], ) def test_safe_gate_chunk_varlen( H: int, + HV: int, D: int, mask_p: float, cu_seqlens: list[int], dtype: torch.dtype, safe_gate: bool, + use_initial_state: bool, + output_final_state: bool, beta_dtype: torch.dtype, ): cula_kda_fused_fwd = get_kda_fused_fwd(device) @@ -239,19 +305,24 @@ def 
test_safe_gate_chunk_varlen( q = torch.randn((1, T, H, D), dtype=dtype) k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn((1, T, H, D), dtype=dtype) - g = F.logsigmoid(torch.randn(1, T, H, D, dtype=torch.float)) + v = torch.randn((1, T, HV, D), dtype=dtype) + g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)) mask = torch.rand_like(g) > mask_p g = g * mask + (~mask) * (-1000) if safe_gate: g = g.clamp(-5, 0) - beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn((N, H, D, D), dtype=torch.float32) + beta = torch.randn(1, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn((N, HV, D, D), dtype=torch.float32) # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance h0_vk = h0.transpose(-1, -2).contiguous() q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + initial_state = h0.clone() if use_initial_state else None + initial_state_vk = h0_vk.clone() if use_initial_state else None + heads_per_group = HV // H + q_ref = q.repeat_interleave(heads_per_group, dim=2) + k_ref = k.repeat_interleave(heads_per_group, dim=2) tri, tri_ht = cula_kda_fused_fwd( q=F.normalize(q.clone(), p=2, dim=-1), @@ -259,8 +330,8 @@ def test_safe_gate_chunk_varlen( v=v.clone(), g=g.clone(), beta=beta.clone(), - initial_state=h0_vk.clone(), - output_final_state=True, + initial_state=initial_state_vk, + output_final_state=output_final_state, cu_seqlens=cu_seqlens, cu_seqlens_cpu=cu_seqlens_cpu, safe_gate=safe_gate, @@ -268,13 +339,13 @@ def test_safe_gate_chunk_varlen( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1), - k=k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), - initial_state=h0.clone(), - output_final_state=True, + initial_state=initial_state, + output_final_state=output_final_state, cu_seqlens=cu_seqlens, cu_seqlens_cpu=cu_seqlens_cpu, safe_gate=safe_gate, @@ -282,13 +353,13 @@ def test_safe_gate_chunk_varlen( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q.clone(), p=2, dim=-1), - k=k.clone(), + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=k_ref.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), - initial_state=h0_vk.clone(), - output_final_state=True, + initial_state=initial_state_vk, + output_final_state=output_final_state, cu_seqlens=cu_seqlens, cu_seqlens_cpu=cu_seqlens_cpu, safe_gate=safe_gate, @@ -300,22 +371,28 @@ def test_safe_gate_chunk_varlen( ref_ht = [] for i in range(N): ref_i, ref_ht_i = naive_recurrent_kda( - q=F.normalize(q[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), - k=k[:, cu_seqlens[i] : cu_seqlens[i + 1]], + q=F.normalize(q_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), + k=k_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]], v=v[:, cu_seqlens[i] : cu_seqlens[i + 1]], beta=beta[:, cu_seqlens[i] : cu_seqlens[i + 1]], g=g[:, cu_seqlens[i] : cu_seqlens[i + 1]], - initial_state=h0[i], - output_final_state=True, + initial_state=h0[i] if use_initial_state else None, + output_final_state=output_final_state, ) ref.append(ref_i) ref_ht.append(ref_ht_i) ref = torch.cat(ref, 1) - ref_ht = torch.cat(ref_ht, 0) + ref_ht = torch.cat(ref_ht, 0) if output_final_state else None assert_close("o", ref, tri, 0.005) - assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) assert_close("o", ref_fla, tri, 0.005) - assert_close("ht", ref_ht_fla, 
tri_ht.transpose(-1, -2), 0.005)
     assert_close("o", ref_fla_trans, tri, 0.005)
-    assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005)
+    if output_final_state:
+        assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005)
+        assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005)
+        assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005)
+    else:
+        assert ref_ht is None
+        assert ref_ht_fla is None
+        assert ref_ht_fla_trans is None
+        assert tri_ht is None


From c01cf10c0460187cc6841211e00c04df0f2e7824 Mon Sep 17 00:00:00 2001
From: sunnyxyli
Date: Wed, 6 May 2026 11:30:40 +0800
Subject: [PATCH 02/17] sm90

---
 csrc/kda/sm90/kernel/tile_scheduler.hpp | 29 +++++++++++--------------
 cula/kda/hopper_fused_fwd.py            |  3 +--
 cula/utils.py                           |  1 -
 3 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/csrc/kda/sm90/kernel/tile_scheduler.hpp b/csrc/kda/sm90/kernel/tile_scheduler.hpp
index ff58382..bdd79bf 100644
--- a/csrc/kda/sm90/kernel/tile_scheduler.hpp
+++ b/csrc/kda/sm90/kernel/tile_scheduler.hpp
@@ -18,27 +18,22 @@
 #include
 #include

-// Tile scheduler: hands out work to every thread block on the GPU, telling each block which sample and which attention head to process, where to start, and how long
-// Under GVA, each program handles one V-head; the V-heads within the same GVA group share the same Q/K
 namespace kda::sm90::kernel {
 using namespace cute;

-// The work descriptor the scheduler sends to each block; once a block has it, it knows where to fetch data and what to process
-// Under GVA: head_idx is the head index of V/O/g/beta (range [0, num_v_heads)),
-// qk_head_idx is the head index of Q/K (range [0, num_qk_heads)).
 struct WorkDesc {
   // coord
-  int32_t seq_idx;     // which sequence am I processing
-  int32_t qk_head_idx; // head idx used by Q/K (representative of the GVA group)
-  int32_t head_idx;    // head idx used by V/O/g/beta
-  int64_t tok_offset;  // where this sequence starts in the big packed array
+  int32_t seq_idx;     // which sequence to process
+  int32_t qk_head_idx; // head idx for Q/K (the representative of the GVA group)
+  int32_t head_idx;    // head idx for V/O/g/beta
+  int64_t tok_offset;  // start offset of this sequence in the packed tensor

   // shape
-  int64_t seq_len; // how long this sequence is
+  int64_t seq_len; // length of this sequence

   // update by mainloop
-  int32_t tile_idx = 0; // which tile we are on (updated in the mainloop)
+  int32_t tile_idx = 0; // current tile index (mutated by the mainloop)

   template
   CUTE_DEVICE bool
@@ -70,14 +65,15 @@ struct WorkDesc {
   }
 };

-// Each block independently handles one (seq, v_head) task; nothing is shared between blocks.
-// GVA optimization: heads_per_group is precomputed on the host and attached to Params, so the device side never repeats the integer division.
+// Each block handles a single (seq, v_head) work item; CTAs do not cooperate.
+// GVA optimization: heads_per_group is precomputed on the host and stored in
+// Params, so the device side does not redo the integer division per CTA.
 struct IndividualTileScheduler {
   struct Params {
     dim3 grid;
     int32_t num_seqs;
     int32_t num_v_heads;
-    int32_t heads_per_group; // = num_v_heads / num_qk_heads, precomputed on the host
+    int32_t heads_per_group; // = num_v_heads / num_qk_heads, precomputed on host
   };

   bool scheduled = false; // a once flag
@@ -93,7 +89,8 @@ struct IndividualTileScheduler {
       cutlass::KernelHardwareInfo const& hw_info,
       ClusterShape const& cluster_shape,
       TileShape const& tile_shape) {
-    // Compute heads_per_group once on the host to avoid an integer division in every CTA
+    // Compute heads_per_group once on the host so every CTA does not have to redo
+    // the integer division.
     int32_t const heads_per_group = problem_size.num_v_heads / problem_size.num_qk_heads;
     dim3 grid(0, 1, 1);
     grid.x = problem_size.num_seqs * problem_size.num_v_heads;
@@ -125,7 +122,7 @@ struct IndividualTileScheduler {
   get_next_work(Params params, ProblemSize const& problem_size) {
     int32_t seq_idx = blockIdx.x / params.num_v_heads;
     int32_t head_idx = blockIdx.x % params.num_v_heads;
-    // GVA: use heads_per_group precomputed on the host to avoid a device-side integer division
+    // GVA: use the host-precomputed heads_per_group to avoid device-side division.
     int32_t qk_head_idx = head_idx / params.heads_per_group;

     int32_t s = problem_size.cu_seqlens[seq_idx];
diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py
index 84c06d4..b658c34 100644
--- a/cula/kda/hopper_fused_fwd.py
+++ b/cula/kda/hopper_fused_fwd.py
@@ -24,7 +24,6 @@
 import cula.cudac as cula_cuda
 from cula.utils import _get_cache_buf, assert_hopper, get_device_sm_count, prepare_uniform_cu_seqlens

-# PyTorch autograd can only trace pure PyTorch ops. Once we call a custom CUDA kernel (cula_cuda.kda_fwd_prefill), the autograd graph is cut, so to train this op together with the upstream network we must subclass torch.autograd.Function and declare forward + backward by hand.
 class HopperChunkKDAFunction(torch.autograd.Function):
     @staticmethod
     @input_guard
@@ -73,7 +72,7 @@ def forward(
             q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta))

         # gate preprocessing
-        if use_gate_in_kernel:  ## let the KDA kernel take over the full gate preprocessing (activation + clamp + cumsum), or only do the final cumsum step
+        if use_gate_in_kernel:
             if safe_gate:
                 assert lower_bound is not None, "lower_bound must be set when use safe_gate"
                 g = kda_gate_chunk_cumsum(
diff --git a/cula/utils.py b/cula/utils.py
index 980e21a..bd70730 100644
--- a/cula/utils.py
+++ b/cula/utils.py
@@ -80,7 +80,6 @@ def assert_hopper(device: torch.device | str | int | None = None) -> None:
         raise RuntimeError(f"Only Hopper GPUs (SM90) are supported, got compute capability sm_{major}{minor}.")


-## Factory that dispatches the forward kernel by GPU architecture: based on the SM version of the CUDA GPU in use, return a kda_prefill implementation matching that hardware
 def get_kda_fused_fwd(device: torch.device | str | int | None = None) -> Callable:
     """Return the appropriate ``kda_prefill`` implementation for *device*.


From f7b39608dbcbd1293c88169ad9b7f4a2917b7124 Mon Sep 17 00:00:00 2001
From: sunnyxyli
Date: Wed, 6 May 2026 11:37:23 +0800
Subject: [PATCH 03/17] sm90

---
 csrc/api/kda_sm90.cu                          | 38 +++++++++++++------
 csrc/api/pybind.cu                            |  3 +-
 csrc/kda/sm90/collective/mainloop_kda_fwd.hpp | 13 +++++++
 cula/kda/hopper_fused_fwd.py                  | 19 +++++++---
 4 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu
index 28c1e49..cd0ffed 100644
--- a/csrc/api/kda_sm90.cu
+++ b/csrc/api/kda_sm90.cu
@@ -34,7 +34,8 @@ kda_fwd_prefill(
     torch::Tensor const& cu_seqlens,
     torch::Tensor workspace_buffer,
     float scale,
-    bool safe_gate) {
+    bool safe_gate,
+    bool output_final_state) {
   // Q, K: [packed_seq, num_qk_heads, D]
   // V/O/g: [packed_seq, num_v_heads, D] (GVA: num_v_heads is a positive integer multiple of num_qk_heads)
   auto packed_seq = q.size(0);
@@ -63,12 +64,23 @@ kda_fwd_prefill(
           {packed_seq, num_v_heads, head_size},
           torch::TensorOptions().dtype(q.dtype()).device(q.device()));

-  // Allocate output state if not provided. State is per V-head.
-  torch::Tensor output_state = output_state_.has_value()
-      ? output_state_.value()
-      : torch::zeros(
-            {num_seqs, num_v_heads, head_size, head_size},
-            torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
+  // Allocate output state only when the caller actually needs it.
When + // output_final_state=False and the caller did not pass an out tensor, we return a + // 0-element placeholder tensor (no GMEM allocation) and pass nullptr to the + // kernel; the kernel will then skip the final-state write-back entirely, saving + // a [N, HV, D, D] fp32 allocation + GMEM store. + torch::Tensor output_state; + bool need_output_state_buffer = output_final_state || output_state_.has_value(); + if (output_state_.has_value()) { + output_state = output_state_.value(); + } else if (need_output_state_buffer) { + output_state = torch::zeros( + {num_seqs, num_v_heads, head_size, head_size}, + torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); + } else { + // 0-element placeholder so the returned tuple never carries an undefined Tensor. + output_state = torch::empty({0}, torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); + } // Validate dtypes TORCH_CHECK(q.dtype() == torch::kBFloat16, "q must be bfloat16"); @@ -81,13 +93,17 @@ kda_fwd_prefill( TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); TORCH_CHECK(output.is_contiguous(), "output must be contiguous"); - TORCH_CHECK(output_state.is_contiguous(), "output_state must be contiguous"); + if (need_output_state_buffer) { + TORCH_CHECK(output_state.is_contiguous(), "output_state must be contiguous"); + } TORCH_CHECK(cu_seqlens.is_contiguous(), "cu_seqlens must be contiguous"); TORCH_CHECK(workspace_buffer.is_contiguous(), "workspace_buffer must be contiguous"); - // Extract optional pointers + // Extract optional pointers. When the caller does not need the final state, + // output_state_ptr stays nullptr and the kernel skips its state write-back. float const* alpha_ptr = nullptr; float const* input_state_ptr = nullptr; + float* output_state_ptr = need_output_state_buffer ? output_state.data_ptr() : nullptr; if (alpha_.has_value()) { auto& alpha = alpha_.value(); @@ -138,7 +154,7 @@ kda_fwd_prefill( kda::sm90::launch_kda_fwd_prefill_kernel( stream, reinterpret_cast(output.data_ptr()), - output_state.data_ptr(), + output_state_ptr, reinterpret_cast(q.data_ptr()), reinterpret_cast(k.data_ptr()), reinterpret_cast(v.data_ptr()), @@ -160,7 +176,7 @@ kda_fwd_prefill( kda::sm90::launch_kda_fwd_prefill_kernel( stream, reinterpret_cast(output.data_ptr()), - output_state.data_ptr(), + output_state_ptr, reinterpret_cast(q.data_ptr()), reinterpret_cast(k.data_ptr()), reinterpret_cast(v.data_ptr()), diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu index ba2deb6..a38ac6d 100644 --- a/csrc/api/pybind.cu +++ b/csrc/api/pybind.cu @@ -64,7 +64,8 @@ kda_fwd_prefill( torch::Tensor const& cu_seqlens, torch::Tensor workspace_buffer, float scale, - bool safe_gate); + bool safe_gate, + bool output_final_state); #endif PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp index a7bacf7..aa3ae6d 100644 --- a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp +++ b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp @@ -917,6 +917,19 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd { }; auto kv_store = [&]() INLINE_LAMBDA { // tKVrKV is carried over whole mainloop + // Skip the final-state write-back entirely when the caller does not need it + // (i.e. output_final_state=False on the public API). The check is uniform + // across the launch, so the branch cost is negligible while it saves a full + // [N, num_v_heads, D, D] float32 store to GMEM. 
+ if (params.ptr_output_state == nullptr) { + DPRINTF0_WG( + "[%d,%d,%d,%d]>> skip tKVrKV -> tKVgKV (output_final_state=false)\n", + seq_idx, + q_head_idx, + k_head_idx, + v_head_idx); + return; + } DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); // GVA: state is stored per V/O head. int num_state_heads = problem_size.num_v_heads; diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py index b658c34..70258be 100644 --- a/cula/kda/hopper_fused_fwd.py +++ b/cula/kda/hopper_fused_fwd.py @@ -113,10 +113,14 @@ def forward( workspace_buffer = _get_cache_buf("hopper_kda_fwd_workspace", workspace_size, q.device) # call the C++ kernel - # Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens, workspace, scale, safe_gate) + # Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, + # alpha_, beta_, cu_seqlens, workspace, scale, + # safe_gate, output_final_state) + # Passing output_final_state lets the C++ side skip allocating + writing back + # the [N, HV, D, D] fp32 state tensor when the caller does not need it. o, final_state = cula_cuda.kda_fwd_prefill( None, # output_ (auto-allocate) - None, # output_state_ (auto-allocate) + None, # output_state_ (auto-allocate iff output_final_state=True) q, k, v, @@ -127,15 +131,18 @@ def forward( workspace_buffer, scale, safe_gate, + output_final_state, ) # reshape back o = rearrange(o, "(b t) h d -> b t h d", b=batch_size) - # Bug fix: respect output_final_state=False explicitly. - # The C++ kernel always allocates an output_state tensor, but the public API - # promises None when the caller did not opt-in. - return o.to(q.dtype), (final_state if output_final_state else None) + # When output_final_state=False, the C++ side returns an empty tensor; surface + # that as None in the public API. We still defend against future signature + # changes by checking output_final_state explicitly. + if not output_final_state: + final_state = None + return o.to(q.dtype), final_state @staticmethod @input_guard From e001f725b71fb48ec5238111b77d08cfcb806b59 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Wed, 6 May 2026 12:08:11 +0800 Subject: [PATCH 04/17] sm90 --- tests/test_kda_fused_fwd.py | 415 ++++++++++++++++++++++++++---------- 1 file changed, 307 insertions(+), 108 deletions(-) diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index e1f0b0f..ec2d7d6 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -29,6 +29,95 @@ pytestmark = pytest.mark.sm90_only +# A single, less-common seed for the whole module; data generation is then made deterministic +# but visibly different from the upstream KDA test suite (which uses 42 throughout). +_SEED = 0xC4FA # 50,426 + + +def _make_qk(B: int, T: int, H: int, D: int, dtype: torch.dtype, *, generator: torch.Generator) -> torch.Tensor: + """Q/K live close to a unit sphere in real workloads (l2-norm is applied + inside the model). 
We sample from a normal and rescale, which exercises a + wider numerical range than the original `rand(0,1)` initialization.""" + x = torch.randn(B, T, H, D, dtype=torch.float32, generator=generator) * 0.5 + return x.to(dtype) + + +def _make_v(B: int, T: int, HV: int, D: int, dtype: torch.dtype, *, generator: torch.Generator) -> torch.Tensor: + """V is an activation; sample N(0, 0.5^2) which is closer to typical post-norm hidden states + than uniform [0, 1].""" + return (torch.randn(B, T, HV, D, dtype=torch.float32, generator=generator) * 0.5).to(dtype) + + +def _make_g( + B: int, T: int, HV: int, D: int, *, dtype: torch.dtype, mask_p: float, gate_logit_normalizer: float, + use_gate_in_kernel: bool, generator: torch.Generator, +) -> torch.Tensor: + """Gate (forget) tensor in log space.""" + g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype, generator=generator) + if use_gate_in_kernel: + return g + g = F.logsigmoid(g) / gate_logit_normalizer + drop_mask = torch.rand(g.shape, dtype=g.dtype, generator=generator) > mask_p + return g * drop_mask + + +def _make_beta(B: int, T: int, HV: int, beta_dtype: torch.dtype, *, generator: torch.Generator) -> torch.Tensor: + """β is gating scalar in [0, 1]; we sample sigmoid(N(0,1)).""" + return torch.randn(B, T, HV, dtype=torch.float32, generator=generator).sigmoid().to(beta_dtype) + + +def _make_h0(B: int, HV: int, D: int, *, generator: torch.Generator) -> torch.Tensor: + """Initial recurrent state. We use a small magnitude so that the recurrence + does not blow up across long sequences.""" + return torch.randn(B, HV, D, D, dtype=torch.float32, generator=generator) * 0.05 + + +# ============================================================================= +# Fixed-length test +# ============================================================================= +# +# Cases are grouped by GVA "heads_per_group = HV / H": +# - group=1 → degenerates to plain MHA (sanity baseline) +# - group=2/4/8/16 → real GVA paths +# We deliberately mix small and large T (incl. non-multiple-of-chunk-size 63/65/1500), +# different B, and toggles for use_qk_l2norm_in_kernel / use_gate_in_kernel / +# use_initial_state / output_final_state to exercise as many code paths as +# possible without blowing up the matrix. 
+# ============================================================================= +_FIXED_CASES = [ + # ---------------- group = 1 (MHA baseline) ---------------- + (1, 63, 1, 1, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (2, 500, 3, 3, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, True, True, False, torch.bfloat16), + (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, True, True, False, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, True, False, True, True, True, False, torch.bfloat16), # qk_l2norm=True + (2, 1500, 4, 4, 128, 10, 0, False, True, True, True, True, False, torch.bfloat16), # gate_in_kernel + (4, 2048, 8, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), + + # ---------------- group = 2 ---------------- + (1, 64, 1, 2, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (2, 512, 2, 4, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (3, 1024, 4, 8, 128, 0.1, 0, False, False, True, True, True, False, torch.bfloat16), + (1, 65, 2, 4, 128, 1, 0, False, False, True, True, True, True, torch.bfloat16), # deterministic + (2, 65, 2, 4, 128, 1, 0, False, False, True, False, False, False, torch.bfloat16), # init=F outstate=F + (4, 768, 4, 8, 128, 1, 0.3, True, False, True, True, True, False, torch.bfloat16), # qk_l2norm + dropout + + # ---------------- group = 4 ---------------- + (1, 65, 1, 4, 128, 1, 0, False, False, True, False, True, False, torch.bfloat16), # init=F + (2, 1024, 2, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), # gate_in_kernel + (1, 256, 2, 8, 128, 1, 0, False, False, True, True, False, False, torch.bfloat16), # outstate=False (skip path) + (2, 4096, 2, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), # long T + + # ---------------- group = 8 ---------------- + (1, 2048, 1, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), + (2, 1024, 2, 16, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (1, 512, 1, 8, 128, 1, 0, False, False, True, True, False, False, torch.bfloat16), # outstate=False + + # ---------------- group = 16 ---------------- + (1, 512, 1, 16, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), + (1, 256, 2, 32, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), +] + @pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @pytest.mark.parametrize( @@ -50,27 +139,12 @@ ), [ pytest.param( - *test, + *case, id=("B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-init{}-outstate{}-deterministic{}-{}").format( - *test + *case ), ) - for test in [ - (1, 63, 1, 1, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (2, 500, 3, 3, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, True, True, False, torch.bfloat16), - (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, True, True, False, torch.bfloat16), - (4, 1024, 4, 4, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (4, 1024, 4, 4, 128, 1, 0, True, False, True, True, True, False, torch.bfloat16), - (2, 1500, 4, 4, 128, 10, 0, False, True, True, True, True, False, torch.bfloat16), - (4, 2048, 8, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), - (2, 512, 2, 4, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (2, 1024, 2, 8, 128, 1, 0, False, True, True, 
True, True, False, torch.bfloat16),
-            (1, 64, 1, 2, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16),
-            (1, 65, 1, 4, 128, 1, 0, False, False, True, False, True, False, torch.bfloat16),
-            (2, 65, 2, 4, 128, 1, 0, False, False, True, False, False, False, torch.bfloat16),
-            (1, 65, 2, 4, 128, 1, 0, False, False, True, True, True, True, torch.bfloat16),
-        ]
+        for case in _FIXED_CASES
     ],
 )
 def test_safe_gate_chunk(
     B: int,
     T: int,
     H: int,
     HV: int,
     D: int,
     gate_logit_normalizer: float,
     mask_p: float,
     use_qk_l2norm_in_kernel: bool,
     use_gate_in_kernel: bool,
     safe_gate: bool,
     use_initial_state: bool,
     output_final_state: bool,
     deterministic: bool,
     dtype: torch.dtype,
     beta_dtype: torch.dtype,
 ):
     from fla.ops.kda.gate import naive_kda_lowerbound_gate
 
     cula_kda_fused_fwd = get_kda_fused_fwd(device)
 
-    torch.manual_seed(42)
-    q = torch.rand(B, T, H, D, dtype=dtype)
-    k = torch.rand(B, T, H, D, dtype=dtype)
-    v = torch.rand(B, T, HV, D, dtype=dtype)
-    g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype)
+    # Use a dedicated torch.Generator so data generation is deterministic for a
+    # fixed _SEED and independent of the global RNG state.
+    gen = torch.Generator(device="cpu").manual_seed(_SEED)
+
+    q = _make_qk(B, T, H, D, dtype, generator=gen)
+    k = _make_qk(B, T, H, D, dtype, generator=gen)
+    v = _make_v(B, T, HV, D, dtype, generator=gen)
+    g = _make_g(
+        B, T, HV, D,
+        dtype=dtype,
+        mask_p=mask_p,
+        gate_logit_normalizer=gate_logit_normalizer,
+        use_gate_in_kernel=use_gate_in_kernel,
+        generator=gen,
+    )
+    beta = _make_beta(B, T, HV, beta_dtype, generator=gen)
+    h0 = _make_h0(B, HV, D, generator=gen)
+
     if deterministic:
+        # Hand-crafted inputs that produce closed-form outputs; useful as a stronger
+        # correctness anchor than RMSE against a reference implementation.
         assert H == 2 and HV == 4 and not use_gate_in_kernel
         q = torch.zeros(B, T, H, D, dtype=dtype)
         k = torch.zeros(B, T, H, D, dtype=dtype)
@@ -111,12 +200,14 @@ def test_safe_gate_chunk(
         k[:, :, 1, 0] = 1
         for i in range(HV):
             v[:, :, i] = i + 1
+        beta = torch.ones(B, T, HV, dtype=beta_dtype)
+        h0 = torch.zeros(B, HV, D, D, dtype=torch.float32)
+
+    A_log = dt_bias = None
     if use_gate_in_kernel:
-        A_log = torch.randn(HV, dtype=torch.float)
-        dt_bias = torch.randn(HV * D, dtype=torch.float)
-    else:
-        g = F.logsigmoid(g) / gate_logit_normalizer
-        g = g * (torch.rand_like(g) > mask_p)
+        A_log = torch.randn(HV, dtype=torch.float, generator=gen)
+        dt_bias = torch.randn(HV * D, dtype=torch.float, generator=gen)
+
     if safe_gate:
         lower_bound = -5.0
         if not use_gate_in_kernel:
@@ -126,19 +217,19 @@ def test_safe_gate_chunk(
         lower_bound = None
         naive_kda_gate_fn = naive_kda_gate
 
-    beta = torch.randn(B, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype)
-    h0 = torch.randn(B, HV, D, D, dtype=torch.float32)
-    if deterministic:
-        beta = torch.ones(B, T, HV, dtype=beta_dtype)
-        h0 = torch.zeros(B, HV, D, D, dtype=torch.float32)
-    # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance
+    # NOTE: for inference scenarios, we only use the transposed state layout for better
+    # decoding performance.
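+    # Layout sketch (assuming h0 [B, HV, D, D] is the [k, v] layout the naive
+    # reference consumes): h0_vk is its [v, k] transpose, i.e. h0_vk[b, h] == h0[b, h].T,
+    # which is why vk-layout kernel states below are compared via tri_ht.transpose(-1, -2).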
h0_vk = h0.transpose(-1, -2).contiguous() if use_gate_in_kernel: - A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(False), (A_log, dt_bias)) - q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + A_log, dt_bias = (x.to(device).requires_grad_(False) for x in (A_log, dt_bias)) + q, k, v, g, beta, h0, h0_vk = ( + x.to(device).requires_grad_(False) for x in (q, k, v, g, beta, h0, h0_vk) + ) initial_state = h0.clone() if use_initial_state else None initial_state_vk = h0_vk.clone() if use_initial_state else None + # GVA reference: replicate Q/K across each group so the naive/MHA reference can + # consume them as if HV-many heads were present. heads_per_group = HV // H q_ref = q.repeat_interleave(heads_per_group, dim=2) k_ref = k.repeat_interleave(heads_per_group, dim=2) @@ -213,7 +304,67 @@ def test_safe_gate_chunk( assert ref_ht is None assert ref_ht_fla is None assert ref_ht_fla_trans is None - assert tri_ht is None + assert tri_ht is None, "wrapper must surface None when output_final_state=False" + + +# ============================================================================= +# Variable-length (cu_seqlens) test +# ============================================================================= +_VARLEN_CASES = [ + # ---------------- group = 1 (MHA baseline) ---------------- + (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True, True, True), + (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True, True, True), + (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True, True, True), + (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True, True, True), + (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True, True, True), + + # ---------------- group = 2 ---------------- + (2, 4, 128, 0, [0, 63, 130], torch.bfloat16, True, True, True), + (2, 4, 128, 0, [0, 17, 64, 65, 130], torch.bfloat16, True, False, True), # init=False + (3, 6, 128, 0.2, [0, 257, 800, 1500], torch.bfloat16, True, True, True), + + # ---------------- group = 4 ---------------- + (1, 4, 128, 0, [0, 1], torch.bfloat16, True, True, True), + (1, 4, 128, 0, [0, 63, 64, 65], torch.bfloat16, True, True, True), + (4, 16, 128, 0.5, [0, 15, 100, 300], torch.bfloat16, True, True, False), # outstate=False (skip path) + (2, 8, 128, 0, [0, 256, 1024, 4096], torch.bfloat16, True, True, True), + + # ---------------- group = 8 ---------------- + (1, 8, 128, 0, [0, 65, 200, 1024], torch.bfloat16, True, True, True), + (1, 8, 128, 0, [0, 1024, 2048], torch.bfloat16, True, False, False), # init=F outstate=F + + # ---------------- group = 16 ---------------- + (1, 16, 128, 0, [0, 257, 1024], torch.bfloat16, True, True, True), + + # ---------------- group = 1, varlen at scale (simulated traces) ---------------- + ( + 32, 32, 128, 0, + [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], + torch.bfloat16, True, True, True, + ), + ( + 32, 32, 128, 0, + [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], + torch.bfloat16, True, True, True, + ), + ( + 32, 32, 128, 0, + [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], + torch.bfloat16, True, True, True, + ), + ( + 32, 32, 128, 0, + [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], + torch.bfloat16, True, True, True, + ), + + # ---------------- group = 4, varlen at scale ---------------- + ( + 8, 32, 128, 
0, + [0, 255, 1024, 2049, 3072, 4097, 5120, 6144, 7168, 8192], + torch.bfloat16, True, True, True, + ), +] @pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) @@ -221,66 +372,10 @@ def test_safe_gate_chunk( ("H", "HV", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate", "use_initial_state", "output_final_state"), [ pytest.param( - *test, - id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}-init{}-outstate{}".format(*test), + *case, + id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}-init{}-outstate{}".format(*case), ) - for test in [ - (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True, True, True), - (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True, True, True), - (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True, True, True), - (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True, True, True), - (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True, True, True), - (2, 4, 128, 0, [0, 63, 130], torch.bfloat16, True, True, True), - (1, 2, 128, 0, [0, 1], torch.bfloat16, True, True, True), - (1, 2, 128, 0, [0, 63, 64, 65], torch.bfloat16, True, True, True), - (2, 4, 128, 0, [0, 17, 64, 65, 130], torch.bfloat16, True, False, True), - (4, 8, 128, 0.5, [0, 15, 100, 300], torch.bfloat16, True, True, False), - # ======Varlen test with simulated trace======= - ( - 32, - 32, - 128, - 0, - [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], - torch.bfloat16, - True, - True, - True, - ), - ( - 32, - 32, - 128, - 0, - [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], - torch.bfloat16, - True, - True, - True, - ), - ( - 32, - 32, - 128, - 0, - [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], - torch.bfloat16, - True, - True, - True, - ), - ( - 32, - 32, - 128, - 0, - [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], - torch.bfloat16, - True, - True, - True, - ), - ] + for case in _VARLEN_CASES ], ) def test_safe_gate_chunk_varlen( @@ -297,27 +392,31 @@ def test_safe_gate_chunk_varlen( ): cula_kda_fused_fwd = get_kda_fused_fwd(device) - torch.manual_seed(42) + gen = torch.Generator(device="cpu").manual_seed(_SEED) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) cu_seqlens_cpu = cu_seqlens.cpu() T = cu_seqlens[-1] N = len(cu_seqlens) - 1 - q = torch.randn((1, T, H, D), dtype=dtype) - k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) - v = torch.randn((1, T, HV, D), dtype=dtype) - g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)) - mask = torch.rand_like(g) > mask_p - g = g * mask + (~mask) * (-1000) + q = (torch.randn((1, T, H, D), dtype=torch.float32, generator=gen) * 0.5).to(dtype) + k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32, generator=gen), p=2, dim=-1).to(dtype) + v = (torch.randn((1, T, HV, D), dtype=torch.float32, generator=gen) * 0.5).to(dtype) + g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float, generator=gen)) + drop_mask = torch.rand(g.shape, dtype=g.dtype, generator=gen) > mask_p + g = g * drop_mask + (~drop_mask) * (-1000) if safe_gate: g = g.clamp(-5, 0) - beta = torch.randn(1, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) - h0 = torch.randn((N, HV, D, D), dtype=torch.float32) - # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance + 
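+    # This clamp mirrors the lower_bound=-5.0 handed to the kernels below, so the
+    # reference and the fused kernel see the same bounded gate.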
beta = _make_beta(1, T, HV, beta_dtype, generator=gen) + h0 = _make_h0(N, HV, D, generator=gen) + # NOTE: for inference scenarios, we only use the transposed state layout for better + # decoding performance. h0_vk = h0.transpose(-1, -2).contiguous() - q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + q, k, v, g, beta, h0, h0_vk = ( + x.to(device).requires_grad_(False) for x in (q, k, v, g, beta, h0, h0_vk) + ) initial_state = h0.clone() if use_initial_state else None initial_state_vk = h0_vk.clone() if use_initial_state else None heads_per_group = HV // H @@ -395,4 +494,104 @@ def test_safe_gate_chunk_varlen( assert ref_ht is None assert ref_ht_fla is None assert ref_ht_fla_trans is None - assert tri_ht is None + assert tri_ht is None, "wrapper must surface None when output_final_state=False" + + +# ============================================================================= +# Regression: output_final_state=False must skip the state buffer and still +# produce correct outputs on a GVA configuration. +# ============================================================================= +@pytest.mark.parametrize( + ("B", "T", "H", "HV"), + [ + (1, 256, 1, 4), + (2, 1024, 2, 8), + (1, 512, 1, 16), + ], + ids=["small-g4", "medium-g4", "wide-g16"], +) +def test_output_final_state_skip_under_gva(B: int, T: int, H: int, HV: int): + """Sanity check for the C++ side optimization: when output_final_state=False + we must (a) still get correct outputs and (b) get None back as the second + return value (no leaked tensor).""" + cula_kda_fused_fwd = get_kda_fused_fwd(device) + gen = torch.Generator(device="cpu").manual_seed(_SEED) + D = 128 + dtype = torch.bfloat16 + + q = _make_qk(B, T, H, D, dtype, generator=gen).to(device) + k = _make_qk(B, T, H, D, dtype, generator=gen).to(device) + v = _make_v(B, T, HV, D, dtype, generator=gen).to(device) + g = _make_g( + B, T, HV, D, + dtype=dtype, mask_p=0.0, gate_logit_normalizer=1.0, use_gate_in_kernel=False, + generator=gen, + ).clamp(-5, 0).to(device) + beta = _make_beta(B, T, HV, torch.float32, generator=gen).to(device) + + # Run twice: once with output_final_state=True (reference path), once with =False. + o_full, ht_full = cula_kda_fused_fwd( + q=F.normalize(q.clone(), p=2, dim=-1), + k=F.normalize(k.clone(), p=2, dim=-1), + v=v.clone(), g=g.clone(), beta=beta.clone(), + output_final_state=True, safe_gate=True, lower_bound=-5.0, + ) + o_skip, ht_skip = cula_kda_fused_fwd( + q=F.normalize(q.clone(), p=2, dim=-1), + k=F.normalize(k.clone(), p=2, dim=-1), + v=v.clone(), g=g.clone(), beta=beta.clone(), + output_final_state=False, safe_gate=True, lower_bound=-5.0, + ) + + assert ht_full is not None, "with output_final_state=True we must get a state tensor" + assert ht_skip is None, "with output_final_state=False the wrapper must surface None" + assert_close("o", o_full, o_skip, 0.005) + + +# ============================================================================= +# API contract: GVA shape validation. 
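+# A valid GVA config needs H > 0 and HV % H == 0 (e.g. H=2, HV=8 gives groups of 4);
+# anything else should fail fast in Python, before any kernel launch.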
+# ============================================================================= +@pytest.mark.parametrize( + ("H", "HV", "expect_error"), + [ + (2, 4, False), # valid: HV / H = 2 + (2, 8, False), # valid: HV / H = 4 + (4, 4, False), # valid: HV == H (MHA) + (3, 4, True), # invalid: 4 % 3 != 0 + (3, 7, True), # invalid: 7 % 3 != 0 + ], +) +def test_gva_shape_validation(H: int, HV: int, expect_error: bool): + """The Python wrapper must reject HV that is not a positive multiple of H + *before* anything reaches the kernel.""" + cula_kda_fused_fwd = get_kda_fused_fwd(device) + B, T, D = 1, 64, 128 + dtype = torch.bfloat16 + gen = torch.Generator(device="cpu").manual_seed(_SEED) + + q = _make_qk(B, T, H, D, dtype, generator=gen).to(device) + k = _make_qk(B, T, H, D, dtype, generator=gen).to(device) + v = _make_v(B, T, HV, D, dtype, generator=gen).to(device) + g = _make_g( + B, T, HV, D, + dtype=dtype, mask_p=0.0, gate_logit_normalizer=1.0, use_gate_in_kernel=False, + generator=gen, + ).clamp(-5, 0).to(device) + beta = _make_beta(B, T, HV, torch.float32, generator=gen).to(device) + + if expect_error: + with pytest.raises(AssertionError): + cula_kda_fused_fwd( + q=F.normalize(q, p=2, dim=-1), + k=F.normalize(k, p=2, dim=-1), + v=v, g=g, beta=beta, + output_final_state=False, safe_gate=True, lower_bound=-5.0, + ) + else: + o, _ = cula_kda_fused_fwd( + q=F.normalize(q, p=2, dim=-1), + k=F.normalize(k, p=2, dim=-1), + v=v, g=g, beta=beta, + output_final_state=False, safe_gate=True, lower_bound=-5.0, + ) + assert o.shape == (B, T, HV, D) From aa9ba8a8d78398e5dd516d86c57d939445b04f26 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Wed, 6 May 2026 14:10:05 +0800 Subject: [PATCH 05/17] sm90 --- tests/test_kda_fused_fwd.py | 933 +++++++++++++++--------------------- 1 file changed, 379 insertions(+), 554 deletions(-) diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index ec2d7d6..a87eed3 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2025-2026 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,586 +13,410 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +""" +bench_kda_fused_fwd.py — Benchmark: cuLA fully-fused KDA forward vs FLA Triton baseline -# Adapted from flash-linear-attention: https://github.com/fla-org/flash-linear-attention/blob/main/tests/ops/test_kda.py +Automatically selects the cuLA fully-fused implementation based on the current +GPU architecture: + - sm100 (Blackwell) → cula.kda.blackwell_fused_fwd.flash_kda_prefill + - sm90 (Hopper) → cula.kda.hopper_fused_fwd.cula_kda_prefill +Compares: + - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton + - Performance: kernel execution time (ms) with CUDA events -import pytest -import torch -import torch.nn.functional as F -from fla.ops import chunk_kda as fla_chunk_kda -from fla.ops.kda.gate import naive_kda_gate -from fla.ops.kda.naive import naive_recurrent_kda -from fla.utils import assert_close, device - -from cula.utils import get_kda_fused_fwd - -pytestmark = pytest.mark.sm90_only +Modes: + - Fixed-length: various (B, T) configs + - Varlen: sequences with 2-3x length variation -# A single, less-common seed for the whole module; data generation is then made deterministic -# but visibly different from the upstream KDA test suite (which uses 42 throughout). 
-_SEED = 0xC4FA # 50,426 +Usage: + python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu] +With --ncu, warmup=1 and iters=1 for ncu profiling: + ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu +""" -def _make_qk(B: int, T: int, H: int, D: int, dtype: torch.dtype, *, generator: torch.Generator) -> torch.Tensor: - """Q/K live close to a unit sphere in real workloads (l2-norm is applied - inside the model). We sample from a normal and rescale, which exercises a - wider numerical range than the original `rand(0,1)` initialization.""" - x = torch.randn(B, T, H, D, dtype=torch.float32, generator=generator) * 0.5 - return x.to(dtype) +import argparse +import os +import pathlib +import sys +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) +os.environ.setdefault("FLA_USE_FAST_OPS", os.getenv("CULA_USE_FAST_MATH", "1")) # Enable fast ops in FLA for fair comparison -def _make_v(B: int, T: int, HV: int, D: int, dtype: torch.dtype, *, generator: torch.Generator) -> torch.Tensor: - """V is an activation; sample N(0, 0.5^2) which is closer to typical post-norm hidden states - than uniform [0, 1].""" - return (torch.randn(B, T, HV, D, dtype=torch.float32, generator=generator) * 0.5).to(dtype) +import torch +from fla.ops.kda import chunk_kda as fla_chunk_kda + +from benchmarks.utils import ( + SEED, + build_varlen_configs, + exclusive_cumsum, + prepare_safe_gate_inputs, + set_seed, +) +from cula.utils import get_device_sm_version, get_kda_fused_fwd + +# ============================================================ +# Resolve cuLA fully-fused implementation at import time +# ============================================================ +_device = torch.device("cuda") +_major, _minor = get_device_sm_version(_device) +_SM_TAG = f"sm{_major}{_minor}" +cula_kda_fused_fwd = get_kda_fused_fwd(_device) + +# ============================================================ +# Constants +# ============================================================ +H, D = 64, 128 +WARMUP = 25 +N_ITERS = 100 +NCU_MODE = False +SANITIZER_MODE = False +HAS_INIT_STATE = False + + +# ============================================================ +# Helpers +# ============================================================ +def time_kernel(fn, warmup=None, n_iters=None): + if warmup is None: + warmup = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP + if n_iters is None: + n_iters = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + for _ in range(n_iters): + fn() + end_evt.record() + torch.cuda.synchronize() + return start_evt.elapsed_time(end_evt) / n_iters + + +def accuracy_stats(ref, out): + """Compute RMSE, relative max diff, and mean absolute difference.""" + ref_f = ref.float() + out_f = out.float() + diff = (ref_f - out_f).abs() + rmse = diff.pow(2).mean().sqrt().item() + max_diff = diff.max().item() + denom = ref_f.abs().max().item() + rel_max = max_diff / denom if denom > 0 else 0.0 + mean_diff = diff.mean().item() + return rmse, rel_max, mean_diff + + +def run_fla(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound): + return fla_chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + initial_state=init_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + use_gate_in_kernel=True, + 
safe_gate=True, + lower_bound=lower_bound, + transpose_state_layout=True, + ) -def _make_g( - B: int, T: int, HV: int, D: int, *, dtype: torch.dtype, mask_p: float, gate_logit_normalizer: float, - use_gate_in_kernel: bool, generator: torch.Generator, -) -> torch.Tensor: - """Gate (forget) tensor in log space.""" - g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype, generator=generator) - if use_gate_in_kernel: - return g - g = F.logsigmoid(g) / gate_logit_normalizer - drop_mask = torch.rand(g.shape, dtype=g.dtype, generator=generator) > mask_p - return g * drop_mask +def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound): + return cula_kda_fused_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + initial_state=init_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + use_gate_in_kernel=True, + safe_gate=True, + lower_bound=lower_bound, + ) -def _make_beta(B: int, T: int, HV: int, beta_dtype: torch.dtype, *, generator: torch.Generator) -> torch.Tensor: - """β is gating scalar in [0, 1]; we sample sigmoid(N(0,1)).""" - return torch.randn(B, T, HV, dtype=torch.float32, generator=generator).sigmoid().to(beta_dtype) +# ============================================================ +# Fixed-length benchmark +# ============================================================ +def bench_fixed(configs): + print("\n" + "=" * 100) + print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print("=" * 100) + results = [] + + for B, T in configs: + set_seed(SEED) + device = torch.device("cuda") + torch.cuda.empty_cache() + + seq_lens = [T] * B + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] + scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) + # Accuracy + o_fla, _ = run_fla(**common) + o_cula, _ = run_cula(**common) + torch.cuda.synchronize() + + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = time_kernel(lambda: run_fla(**common)) + ms_cula = time_kernel(lambda: run_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + results.append( + { + "B": B, + "T": T, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + ) -def _make_h0(B: int, HV: int, D: int, *, generator: torch.Generator) -> torch.Tensor: - """Initial recurrent state. 
We use a small magnitude so that the recurrence - does not blow up across long sequences.""" - return torch.randn(B, HV, D, D, dtype=torch.float32, generator=generator) * 0.05 + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# Varlen benchmark +# ============================================================ +def bench_varlen(configs): + print("\n" + "=" * 100) + print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print("=" * 100) + results = [] + + for seq_lens, total_len, dist in configs: + set_seed(SEED) + device = torch.device("cuda") + torch.cuda.empty_cache() + + T = total_len + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] + scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) + # Accuracy + o_fla, _ = run_fla(**common) + o_cula, _ = run_cula(**common) + torch.cuda.synchronize() + + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = time_kernel(lambda: run_fla(**common)) + ms_cula = time_kernel(lambda: run_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + n_seqs = len(seq_lens) + min_l, max_l = min(seq_lens), max(seq_lens) + avg_l = T // n_seqs + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" + + results.append( + { + "tag": tag, + "dist": dist, + "T_total": T, + "n_seqs": n_seqs, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + ) -# ============================================================================= -# Fixed-length test -# ============================================================================= -# -# Cases are grouped by GVA "heads_per_group = HV / H": -# - group=1 → degenerates to plain MHA (sanity baseline) -# - group=2/4/8/16 → real GVA paths -# We deliberately mix small and large T (incl. non-multiple-of-chunk-size 63/65/1500), -# different B, and toggles for use_qk_l2norm_in_kernel / use_gate_in_kernel / -# use_initial_state / output_final_state to exercise as many code paths as -# possible without blowing up the matrix. 
-# ============================================================================= -_FIXED_CASES = [ - # ---------------- group = 1 (MHA baseline) ---------------- - (1, 63, 1, 1, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (2, 500, 3, 3, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, True, True, False, torch.bfloat16), - (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, True, True, False, torch.bfloat16), - (4, 1024, 4, 4, 128, 1, 0, True, False, True, True, True, False, torch.bfloat16), # qk_l2norm=True - (2, 1500, 4, 4, 128, 10, 0, False, True, True, True, True, False, torch.bfloat16), # gate_in_kernel - (4, 2048, 8, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), - - # ---------------- group = 2 ---------------- - (1, 64, 1, 2, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (2, 512, 2, 4, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (3, 1024, 4, 8, 128, 0.1, 0, False, False, True, True, True, False, torch.bfloat16), - (1, 65, 2, 4, 128, 1, 0, False, False, True, True, True, True, torch.bfloat16), # deterministic - (2, 65, 2, 4, 128, 1, 0, False, False, True, False, False, False, torch.bfloat16), # init=F outstate=F - (4, 768, 4, 8, 128, 1, 0.3, True, False, True, True, True, False, torch.bfloat16), # qk_l2norm + dropout - - # ---------------- group = 4 ---------------- - (1, 65, 1, 4, 128, 1, 0, False, False, True, False, True, False, torch.bfloat16), # init=F - (2, 1024, 2, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), # gate_in_kernel - (1, 256, 2, 8, 128, 1, 0, False, False, True, True, False, False, torch.bfloat16), # outstate=False (skip path) - (2, 4096, 2, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), # long T - - # ---------------- group = 8 ---------------- - (1, 2048, 1, 8, 128, 1, 0, False, True, True, True, True, False, torch.bfloat16), - (2, 1024, 2, 16, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (1, 512, 1, 8, 128, 1, 0, False, False, True, True, False, False, torch.bfloat16), # outstate=False - - # ---------------- group = 16 ---------------- - (1, 512, 1, 16, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), - (1, 256, 2, 32, 128, 1, 0, False, False, True, True, True, False, torch.bfloat16), -] - - -@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) -@pytest.mark.parametrize( - ( - "B", - "T", - "H", - "HV", - "D", - "gate_logit_normalizer", - "mask_p", - "use_qk_l2norm_in_kernel", - "use_gate_in_kernel", - "safe_gate", - "use_initial_state", - "output_final_state", - "deterministic", - "dtype", - ), - [ - pytest.param( - *case, - id=("B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-init{}-outstate{}-deterministic{}-{}").format( - *case - ), + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# Report +# ============================================================ +def print_report(fixed_results, varlen_results): + sep = "=" * 110 + print(f"\n\n{sep}") + print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)") + print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton") + print(f" H={H} D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}") + wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP + ni = 1 if (NCU_MODE or 
SANITIZER_MODE) else N_ITERS + mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") + print(f" Warmup={wu} Iters={ni}{mode_tag}") + print(sep) + + if fixed_results: + print("\n [Fixed-Length]") + print(f" {'─' * 90}") + print( + f" {'B':>3s} {'T':>6s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" + f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) - for case in _FIXED_CASES - ], -) -def test_safe_gate_chunk( - B: int, - T: int, - H: int, - HV: int, - D: int, - gate_logit_normalizer: float, - mask_p: float, - use_qk_l2norm_in_kernel: bool, - use_gate_in_kernel: bool, - safe_gate: bool, - use_initial_state: bool, - output_final_state: bool, - deterministic: bool, - dtype: torch.dtype, - beta_dtype: torch.dtype, -): - from fla.ops.kda.gate import naive_kda_lowerbound_gate - - cula_kda_fused_fwd = get_kda_fused_fwd(device) - - # Use a torch.Generator so each tensor draws from an independent stream while - # the overall test is still deterministic for a fixed _SEED. - gen = torch.Generator(device="cpu").manual_seed(_SEED) - - q = _make_qk(B, T, H, D, dtype, generator=gen) - k = _make_qk(B, T, H, D, dtype, generator=gen) - v = _make_v(B, T, HV, D, dtype, generator=gen) - g = _make_g( - B, T, HV, D, - dtype=dtype, - mask_p=mask_p, - gate_logit_normalizer=gate_logit_normalizer, - use_gate_in_kernel=use_gate_in_kernel, - generator=gen, + print(f" {'─' * 90}") + for r in fixed_results: + print( + f" {r['B']:3d} {r['T']:6d} │ " + f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" + ) + print(f" {'─' * 90}") + + if varlen_results: + print("\n [Varlen]") + print(f" {'─' * 105}") + print( + f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" + f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + ) + print(f" {'─' * 105}") + for r in varlen_results: + print( + f" {r['tag']:>45s} │ " + f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" + ) + print(f" {'─' * 105}") + + print(f"\n{sep}\n") + + +# ============================================================ +# Main +# ============================================================ +def main(): + parser = argparse.ArgumentParser(description="bench_kda_fused_fwd: cuLA fully-fused KDA forward vs FLA Triton") + parser.add_argument( + "--mode", + type=str, + default="both", + choices=["fixed", "varlen", "both"], + help="Which benchmark mode to run (default: both)", ) - beta = _make_beta(B, T, HV, beta_dtype, generator=gen) - h0 = _make_h0(B, HV, D, generator=gen) - - if deterministic: - # Hand-crafted inputs that produce closed-form outputs; useful as a stronger - # correctness anchor than just RMSE-against-reference. 
- assert H == 2 and HV == 4 and not use_gate_in_kernel - q = torch.zeros(B, T, H, D, dtype=dtype) - k = torch.zeros(B, T, H, D, dtype=dtype) - v = torch.zeros(B, T, HV, D, dtype=dtype) - g = torch.zeros(B, T, HV, D, dtype=torch.float) - q[:, :, 0, 0] = 1 - q[:, :, 1, 1] = 1 - k[:, :, 0, 0] = 1 - k[:, :, 1, 0] = 1 - for i in range(HV): - v[:, :, i] = i + 1 - beta = torch.ones(B, T, HV, dtype=beta_dtype) - h0 = torch.zeros(B, HV, D, D, dtype=torch.float32) - - A_log = dt_bias = None - if use_gate_in_kernel: - A_log = torch.randn(HV, dtype=torch.float, generator=gen) - dt_bias = torch.randn(HV * D, dtype=torch.float, generator=gen) - - if safe_gate: - lower_bound = -5.0 - if not use_gate_in_kernel: - g = g.clamp(-5, 0) - naive_kda_gate_fn = naive_kda_lowerbound_gate - else: - lower_bound = None - naive_kda_gate_fn = naive_kda_gate - - # NOTE: for inference scenarios, we only use the transposed state layout for better - # decoding performance. - h0_vk = h0.transpose(-1, -2).contiguous() - if use_gate_in_kernel: - A_log, dt_bias = (x.to(device).requires_grad_(False) for x in (A_log, dt_bias)) - q, k, v, g, beta, h0, h0_vk = ( - x.to(device).requires_grad_(False) for x in (q, k, v, g, beta, h0, h0_vk) + parser.add_argument( + "--ncu", + action="store_true", + help="NCU profiling mode: warmup=1, iters=1", ) - initial_state = h0.clone() if use_initial_state else None - initial_state_vk = h0_vk.clone() if use_initial_state else None - - # GVA reference: replicate Q/K across each group so the naive/MHA reference can - # consume them as if HV-many heads were present. - heads_per_group = HV // H - q_ref = q.repeat_interleave(heads_per_group, dim=2) - k_ref = k.repeat_interleave(heads_per_group, dim=2) - - ref, ref_ht = naive_recurrent_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=F.normalize(k_ref.clone(), p=2, dim=-1), - v=v.clone(), - g=(naive_kda_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), - beta=beta.clone(), - initial_state=initial_state, - output_final_state=output_final_state, + parser.add_argument( + "--sanitizer", + action="store_true", + help="Sanitizer mode: warmup=1, iters=1", ) - - ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), - v=v.clone(), - g=g.clone(), - beta=beta.clone(), - A_log=(A_log.clone() if use_gate_in_kernel else None), - dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=initial_state, - output_final_state=output_final_state, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - use_gate_in_kernel=use_gate_in_kernel, - safe_gate=safe_gate, - lower_bound=lower_bound, + parser.add_argument( + "--init_state", + action="store_true", + help="Use non-zero initial state (default: False)", ) - - ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), - v=v.clone(), - g=g.clone(), - beta=beta.clone(), - A_log=(A_log.clone() if use_gate_in_kernel else None), - dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=initial_state_vk, - output_final_state=output_final_state, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - use_gate_in_kernel=use_gate_in_kernel, - safe_gate=safe_gate, - lower_bound=lower_bound, - transpose_state_layout=True, + args = 
parser.parse_args() + + global NCU_MODE, SANITIZER_MODE, HAS_INIT_STATE + if args.ncu: + NCU_MODE = True + print("[NCU mode] warmup=1, iters=1") + if args.sanitizer: + SANITIZER_MODE = True + print("[Sanitizer mode] warmup=1, iters=1") + if args.init_state: + HAS_INIT_STATE = True + print("[init_state] using non-zero initial state") + + print( + f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG} → using {cula_kda_fused_fwd.__module__}.{cula_kda_fused_fwd.__name__}" ) - tri, tri_ht = cula_kda_fused_fwd( - q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), - k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), - v=v.clone(), - g=g.clone(), - beta=beta.clone(), - A_log=(A_log.clone() if use_gate_in_kernel else None), - dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), - initial_state=initial_state_vk, - output_final_state=output_final_state, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - use_gate_in_kernel=use_gate_in_kernel, - safe_gate=safe_gate, - lower_bound=lower_bound, + fixed_configs = [ + # (B, T) + (1, 512), + (1, 1024), + (1, 4096), + (1, 8192), + (1, 16384), + (2, 512), + (2, 1024), + (2, 4096), + (2, 8192), + (2, 16384), + ] + + varlen_configs = build_varlen_configs( + num_seqs_list=(10, 20), + total_lens=(4096, 8192, 16384), + dists=("uniform", "random", "skewed"), ) - assert_close("o", ref, tri, 0.005) - assert_close("o", ref_fla, tri, 0.005) - assert_close("o", ref_fla_trans, tri, 0.005) - if output_final_state: - assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) - assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) - assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005) - else: - assert ref_ht is None - assert ref_ht_fla is None - assert ref_ht_fla_trans is None - assert tri_ht is None, "wrapper must surface None when output_final_state=False" - - -# ============================================================================= -# Variable-length (cu_seqlens) test -# ============================================================================= -_VARLEN_CASES = [ - # ---------------- group = 1 (MHA baseline) ---------------- - (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True, True, True), - (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True, True, True), - (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True, True, True), - (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True, True, True), - (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True, True, True), - - # ---------------- group = 2 ---------------- - (2, 4, 128, 0, [0, 63, 130], torch.bfloat16, True, True, True), - (2, 4, 128, 0, [0, 17, 64, 65, 130], torch.bfloat16, True, False, True), # init=False - (3, 6, 128, 0.2, [0, 257, 800, 1500], torch.bfloat16, True, True, True), - - # ---------------- group = 4 ---------------- - (1, 4, 128, 0, [0, 1], torch.bfloat16, True, True, True), - (1, 4, 128, 0, [0, 63, 64, 65], torch.bfloat16, True, True, True), - (4, 16, 128, 0.5, [0, 15, 100, 300], torch.bfloat16, True, True, False), # outstate=False (skip path) - (2, 8, 128, 0, [0, 256, 1024, 4096], torch.bfloat16, True, True, True), - - # ---------------- group = 8 ---------------- - (1, 8, 128, 0, [0, 65, 200, 1024], torch.bfloat16, True, True, True), - (1, 8, 128, 0, [0, 1024, 2048], torch.bfloat16, True, False, False), # init=F outstate=F - - # ---------------- group = 16 ---------------- - (1, 16, 128, 0, [0, 257, 1024], torch.bfloat16, True, True, True), - - # 
---------------- group = 1, varlen at scale (simulated traces) ---------------- - ( - 32, 32, 128, 0, - [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], - torch.bfloat16, True, True, True, - ), - ( - 32, 32, 128, 0, - [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], - torch.bfloat16, True, True, True, - ), - ( - 32, 32, 128, 0, - [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], - torch.bfloat16, True, True, True, - ), - ( - 32, 32, 128, 0, - [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], - torch.bfloat16, True, True, True, - ), - - # ---------------- group = 4, varlen at scale ---------------- - ( - 8, 32, 128, 0, - [0, 255, 1024, 2049, 3072, 4097, 5120, 6144, 7168, 8192], - torch.bfloat16, True, True, True, - ), -] - - -@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) -@pytest.mark.parametrize( - ("H", "HV", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate", "use_initial_state", "output_final_state"), - [ - pytest.param( - *case, - id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}-init{}-outstate{}".format(*case), - ) - for case in _VARLEN_CASES - ], -) -def test_safe_gate_chunk_varlen( - H: int, - HV: int, - D: int, - mask_p: float, - cu_seqlens: list[int], - dtype: torch.dtype, - safe_gate: bool, - use_initial_state: bool, - output_final_state: bool, - beta_dtype: torch.dtype, -): - cula_kda_fused_fwd = get_kda_fused_fwd(device) - - gen = torch.Generator(device="cpu").manual_seed(_SEED) - - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - cu_seqlens_cpu = cu_seqlens.cpu() - T = cu_seqlens[-1] - N = len(cu_seqlens) - 1 - - q = (torch.randn((1, T, H, D), dtype=torch.float32, generator=gen) * 0.5).to(dtype) - k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32, generator=gen), p=2, dim=-1).to(dtype) - v = (torch.randn((1, T, HV, D), dtype=torch.float32, generator=gen) * 0.5).to(dtype) - g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float, generator=gen)) - drop_mask = torch.rand(g.shape, dtype=g.dtype, generator=gen) > mask_p - g = g * drop_mask + (~drop_mask) * (-1000) - if safe_gate: - g = g.clamp(-5, 0) - - beta = _make_beta(1, T, HV, beta_dtype, generator=gen) - h0 = _make_h0(N, HV, D, generator=gen) - # NOTE: for inference scenarios, we only use the transposed state layout for better - # decoding performance. 
- h0_vk = h0.transpose(-1, -2).contiguous() - - q, k, v, g, beta, h0, h0_vk = ( - x.to(device).requires_grad_(False) for x in (q, k, v, g, beta, h0, h0_vk) - ) - initial_state = h0.clone() if use_initial_state else None - initial_state_vk = h0_vk.clone() if use_initial_state else None - heads_per_group = HV // H - q_ref = q.repeat_interleave(heads_per_group, dim=2) - k_ref = k.repeat_interleave(heads_per_group, dim=2) - - tri, tri_ht = cula_kda_fused_fwd( - q=F.normalize(q.clone(), p=2, dim=-1), - k=k.clone(), - v=v.clone(), - g=g.clone(), - beta=beta.clone(), - initial_state=initial_state_vk, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - cu_seqlens_cpu=cu_seqlens_cpu, - safe_gate=safe_gate, - lower_bound=-5.0 if safe_gate else None, - ) + fixed_res, varlen_res = [], [] - ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), - v=v.clone(), - g=g.clone(), - beta=beta.clone(), - initial_state=initial_state, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - cu_seqlens_cpu=cu_seqlens_cpu, - safe_gate=safe_gate, - lower_bound=-5.0 if safe_gate else None, - ) + if args.mode in ("fixed", "both"): + fixed_res = bench_fixed(fixed_configs) - ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), - v=v.clone(), - g=g.clone(), - beta=beta.clone(), - initial_state=initial_state_vk, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - cu_seqlens_cpu=cu_seqlens_cpu, - safe_gate=safe_gate, - lower_bound=-5.0 if safe_gate else None, - transpose_state_layout=True, - ) + if args.mode in ("varlen", "both"): + varlen_res = bench_varlen(varlen_configs) - ref = [] - ref_ht = [] - for i in range(N): - ref_i, ref_ht_i = naive_recurrent_kda( - q=F.normalize(q_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), - k=k_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]], - v=v[:, cu_seqlens[i] : cu_seqlens[i + 1]], - beta=beta[:, cu_seqlens[i] : cu_seqlens[i + 1]], - g=g[:, cu_seqlens[i] : cu_seqlens[i + 1]], - initial_state=h0[i] if use_initial_state else None, - output_final_state=output_final_state, - ) - ref.append(ref_i) - ref_ht.append(ref_ht_i) - ref = torch.cat(ref, 1) - ref_ht = torch.cat(ref_ht, 0) if output_final_state else None - - assert_close("o", ref, tri, 0.005) - assert_close("o", ref_fla, tri, 0.005) - assert_close("o", ref_fla_trans, tri, 0.005) - if output_final_state: - assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) - assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) - assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005) - else: - assert ref_ht is None - assert ref_ht_fla is None - assert ref_ht_fla_trans is None - assert tri_ht is None, "wrapper must surface None when output_final_state=False" - - -# ============================================================================= -# Regression: output_final_state=False must skip the state buffer and still -# produce correct outputs on a GVA configuration. 
-# ============================================================================= -@pytest.mark.parametrize( - ("B", "T", "H", "HV"), - [ - (1, 256, 1, 4), - (2, 1024, 2, 8), - (1, 512, 1, 16), - ], - ids=["small-g4", "medium-g4", "wide-g16"], -) -def test_output_final_state_skip_under_gva(B: int, T: int, H: int, HV: int): - """Sanity check for the C++ side optimization: when output_final_state=False - we must (a) still get correct outputs and (b) get None back as the second - return value (no leaked tensor).""" - cula_kda_fused_fwd = get_kda_fused_fwd(device) - gen = torch.Generator(device="cpu").manual_seed(_SEED) - D = 128 - dtype = torch.bfloat16 - - q = _make_qk(B, T, H, D, dtype, generator=gen).to(device) - k = _make_qk(B, T, H, D, dtype, generator=gen).to(device) - v = _make_v(B, T, HV, D, dtype, generator=gen).to(device) - g = _make_g( - B, T, HV, D, - dtype=dtype, mask_p=0.0, gate_logit_normalizer=1.0, use_gate_in_kernel=False, - generator=gen, - ).clamp(-5, 0).to(device) - beta = _make_beta(B, T, HV, torch.float32, generator=gen).to(device) - - # Run twice: once with output_final_state=True (reference path), once with =False. - o_full, ht_full = cula_kda_fused_fwd( - q=F.normalize(q.clone(), p=2, dim=-1), - k=F.normalize(k.clone(), p=2, dim=-1), - v=v.clone(), g=g.clone(), beta=beta.clone(), - output_final_state=True, safe_gate=True, lower_bound=-5.0, - ) - o_skip, ht_skip = cula_kda_fused_fwd( - q=F.normalize(q.clone(), p=2, dim=-1), - k=F.normalize(k.clone(), p=2, dim=-1), - v=v.clone(), g=g.clone(), beta=beta.clone(), - output_final_state=False, safe_gate=True, lower_bound=-5.0, - ) + print_report(fixed_res, varlen_res) - assert ht_full is not None, "with output_final_state=True we must get a state tensor" - assert ht_skip is None, "with output_final_state=False the wrapper must surface None" - assert_close("o", o_full, o_skip, 0.005) - - -# ============================================================================= -# API contract: GVA shape validation. 
-# ============================================================================= -@pytest.mark.parametrize( - ("H", "HV", "expect_error"), - [ - (2, 4, False), # valid: HV / H = 2 - (2, 8, False), # valid: HV / H = 4 - (4, 4, False), # valid: HV == H (MHA) - (3, 4, True), # invalid: 4 % 3 != 0 - (3, 7, True), # invalid: 7 % 3 != 0 - ], -) -def test_gva_shape_validation(H: int, HV: int, expect_error: bool): - """The Python wrapper must reject HV that is not a positive multiple of H - *before* anything reaches the kernel.""" - cula_kda_fused_fwd = get_kda_fused_fwd(device) - B, T, D = 1, 64, 128 - dtype = torch.bfloat16 - gen = torch.Generator(device="cpu").manual_seed(_SEED) - - q = _make_qk(B, T, H, D, dtype, generator=gen).to(device) - k = _make_qk(B, T, H, D, dtype, generator=gen).to(device) - v = _make_v(B, T, HV, D, dtype, generator=gen).to(device) - g = _make_g( - B, T, HV, D, - dtype=dtype, mask_p=0.0, gate_logit_normalizer=1.0, use_gate_in_kernel=False, - generator=gen, - ).clamp(-5, 0).to(device) - beta = _make_beta(B, T, HV, torch.float32, generator=gen).to(device) - - if expect_error: - with pytest.raises(AssertionError): - cula_kda_fused_fwd( - q=F.normalize(q, p=2, dim=-1), - k=F.normalize(k, p=2, dim=-1), - v=v, g=g, beta=beta, - output_final_state=False, safe_gate=True, lower_bound=-5.0, - ) - else: - o, _ = cula_kda_fused_fwd( - q=F.normalize(q, p=2, dim=-1), - k=F.normalize(k, p=2, dim=-1), - v=v, g=g, beta=beta, - output_final_state=False, safe_gate=True, lower_bound=-5.0, - ) - assert o.shape == (B, T, HV, D) + return fixed_res, varlen_res + + +if __name__ == "__main__": + main() From c6886fb63eee1ae6f5902c8407388b995dce8989 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Wed, 6 May 2026 14:32:17 +0800 Subject: [PATCH 06/17] sm90 --- tests/test_kda_fused_fwd.py | 707 ++++++++++++++++-------------------- 1 file changed, 322 insertions(+), 385 deletions(-) diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index a87eed3..fba830a 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright 2025-2026 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,410 +12,348 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" -bench_kda_fused_fwd.py — Benchmark: cuLA fully-fused KDA forward vs FLA Triton baseline +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -Automatically selects the cuLA fully-fused implementation based on the current -GPU architecture: - - sm100 (Blackwell) → cula.kda.blackwell_fused_fwd.flash_kda_prefill - - sm90 (Hopper) → cula.kda.hopper_fused_fwd.cula_kda_prefill +# Adapted from flash-linear-attention: https://github.com/fla-org/flash-linear-attention/blob/main/tests/ops/test_kda.py -Compares: - - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton - - Performance: kernel execution time (ms) with CUDA events - -Modes: - - Fixed-length: various (B, T) configs - - Varlen: sequences with 2-3x length variation - -Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu] - -With --ncu, warmup=1 and iters=1 for ncu profiling: - ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu -""" - -import argparse -import os -import pathlib -import sys - -sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) -os.environ.setdefault("FLA_USE_FAST_OPS", os.getenv("CULA_USE_FAST_MATH", "1")) # Enable fast ops in FLA for fair comparison +import pytest import torch -from fla.ops.kda import chunk_kda as fla_chunk_kda - -from benchmarks.utils import ( - SEED, - build_varlen_configs, - exclusive_cumsum, - prepare_safe_gate_inputs, - set_seed, +import torch.nn.functional as F +from fla.ops import chunk_kda as fla_chunk_kda +from fla.ops.kda.gate import naive_kda_gate +from fla.ops.kda.naive import naive_recurrent_kda +from fla.utils import assert_close, device + +from cula.utils import get_kda_fused_fwd + +pytestmark = pytest.mark.sm90_only + + +@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) +@pytest.mark.parametrize( + ( + "B", + "T", + "H", + "HV", + "D", + "gate_logit_normalizer", + "mask_p", + "use_qk_l2norm_in_kernel", + "use_gate_in_kernel", + "safe_gate", + "use_initial_state", + "dtype", + ), + [ + pytest.param( + *test, + id=("B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-init{}-{}").format( + *test + ), + ) + for test in [ + (1, 63, 1, 1, 128, 1, 0, False, False, True, True, torch.bfloat16), + (2, 500, 3, 3, 128, 1, 0, False, False, True, True, torch.bfloat16), + (2, 1000, 3, 3, 128, 1, 0.5, False, False, True, True, torch.bfloat16), + (3, 1024, 4, 4, 128, 0.1, 0, False, False, True, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, False, False, True, True, torch.bfloat16), + (4, 1024, 4, 4, 128, 1, 0, True, False, True, True, torch.bfloat16), + (2, 1500, 4, 4, 128, 10, 0, False, True, True, True, torch.bfloat16), + (4, 2048, 8, 8, 128, 1, 0, False, True, True, True, torch.bfloat16), + (2, 512, 2, 4, 128, 1, 0, False, False, True, True, torch.bfloat16), + (2, 1024, 2, 8, 128, 1, 0, False, True, True, True, torch.bfloat16), + (1, 64, 1, 2, 128, 1, 0, False, False, True, True, torch.bfloat16), + (1, 65, 1, 4, 128, 1, 0, False, False, True, False, torch.bfloat16), + ] + ], ) -from cula.utils import get_device_sm_version, get_kda_fused_fwd - -# ============================================================ -# Resolve cuLA fully-fused implementation at import time -# ============================================================ -_device = torch.device("cuda") -_major, _minor = get_device_sm_version(_device) -_SM_TAG = f"sm{_major}{_minor}" -cula_kda_fused_fwd = get_kda_fused_fwd(_device) - -# 
============================================================ -# Constants -# ============================================================ -H, D = 64, 128 -WARMUP = 25 -N_ITERS = 100 -NCU_MODE = False -SANITIZER_MODE = False -HAS_INIT_STATE = False - - -# ============================================================ -# Helpers -# ============================================================ -def time_kernel(fn, warmup=None, n_iters=None): - if warmup is None: - warmup = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP - if n_iters is None: - n_iters = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS - for _ in range(warmup): - fn() - torch.cuda.synchronize() - start_evt = torch.cuda.Event(enable_timing=True) - end_evt = torch.cuda.Event(enable_timing=True) - start_evt.record() - for _ in range(n_iters): - fn() - end_evt.record() - torch.cuda.synchronize() - return start_evt.elapsed_time(end_evt) / n_iters - - -def accuracy_stats(ref, out): - """Compute RMSE, relative max diff, and mean absolute difference.""" - ref_f = ref.float() - out_f = out.float() - diff = (ref_f - out_f).abs() - rmse = diff.pow(2).mean().sqrt().item() - max_diff = diff.max().item() - denom = ref_f.abs().max().item() - rel_max = max_diff / denom if denom > 0 else 0.0 - mean_diff = diff.mean().item() - return rmse, rel_max, mean_diff +def test_safe_gate_chunk( + B: int, + T: int, + H: int, + HV: int, + D: int, + gate_logit_normalizer: float, + mask_p: float, + use_qk_l2norm_in_kernel: bool, + use_gate_in_kernel: bool, + safe_gate: bool, + use_initial_state: bool, + dtype: torch.dtype, + beta_dtype: torch.dtype, +): + from fla.ops.kda.gate import naive_kda_lowerbound_gate + + cula_kda_fused_fwd = get_kda_fused_fwd(device) + + torch.manual_seed(42) + q = torch.rand(B, T, H, D, dtype=dtype) + k = torch.rand(B, T, H, D, dtype=dtype) + v = torch.rand(B, T, HV, D, dtype=dtype) + g = torch.randn(B, T, HV, D, dtype=torch.float if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = torch.randn(HV, dtype=torch.float) + dt_bias = torch.randn(HV * D, dtype=torch.float) + else: + g = F.logsigmoid(g) / gate_logit_normalizer + g = g * (torch.rand_like(g) > mask_p) + if safe_gate: + lower_bound = -5.0 + if not use_gate_in_kernel: + g = g.clamp(-5, 0) + naive_kda_gate_fn = naive_kda_lowerbound_gate + else: + lower_bound = None + naive_kda_gate_fn = naive_kda_gate + + beta = torch.randn(B, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn(B, HV, D, D, dtype=torch.float32) + # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance + h0_vk = h0.transpose(-1, -2).contiguous() + if use_gate_in_kernel: + A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(False), (A_log, dt_bias)) + q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + initial_state = h0.clone() if use_initial_state else None + initial_state_vk = h0_vk.clone() if use_initial_state else None + + heads_per_group = HV // H + q_ref = q.repeat_interleave(heads_per_group, dim=2) + k_ref = k.repeat_interleave(heads_per_group, dim=2) + + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=F.normalize(k_ref.clone(), p=2, dim=-1), + v=v.clone(), + g=(naive_kda_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), + beta=beta.clone(), + initial_state=initial_state, + output_final_state=True, + ) + ref_fla, ref_ht_fla = fla_chunk_kda( + q=F.normalize(q_ref.clone(), p=2, dim=-1) if not 
use_qk_l2norm_in_kernel else q_ref.clone(), + k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=initial_state, + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + ) -def run_fla(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound): - return fla_chunk_kda( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - initial_state=init_state, + ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( + q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), + k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=initial_state_vk, output_final_state=True, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - use_gate_in_kernel=True, - safe_gate=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, lower_bound=lower_bound, transpose_state_layout=True, ) - -def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lower_bound): - return cula_kda_fused_fwd( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - initial_state=init_state, + tri, tri_ht = cula_kda_fused_fwd( + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=initial_state_vk, output_final_state=True, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - use_gate_in_kernel=True, - safe_gate=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, lower_bound=lower_bound, ) - -# ============================================================ -# Fixed-length benchmark -# ============================================================ -def bench_fixed(configs): - print("\n" + "=" * 100) - print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") - print("=" * 100) - results = [] - - for B, T in configs: - set_seed(SEED) - device = torch.device("cuda") - torch.cuda.empty_cache() - - seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] - A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] - scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, + assert_close("o", ref, tri, 0.005) + assert_close("o", ref_fla, 
tri, 0.005) + assert_close("o", ref_fla_trans, tri, 0.005) + assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005) + assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005) + assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005) + + +@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"]) +@pytest.mark.parametrize( + ("H", "HV", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate", "use_initial_state"), + [ + pytest.param( + *test, + id="H{}-HV{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}-init{}".format(*test), ) - - # Accuracy - o_fla, _ = run_fla(**common) - o_cula, _ = run_cula(**common) - torch.cuda.synchronize() - - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) - - # Performance - ms_fla = time_kernel(lambda: run_fla(**common)) - ms_cula = time_kernel(lambda: run_cula(**common)) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - results.append( - { - "B": B, - "T": T, - "rmse": rmse, - "rel_max": rel_max, - "mean_diff": mean_diff, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - ) - - del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs - torch.cuda.empty_cache() - - return results - - -# ============================================================ -# Varlen benchmark -# ============================================================ -def bench_varlen(configs): - print("\n" + "=" * 100) - print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") - print("=" * 100) - results = [] - - for seq_lens, total_len, dist in configs: - set_seed(SEED) - device = torch.device("cuda") - torch.cuda.empty_cache() - - T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] - A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] - scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - ) - - # Accuracy - o_fla, _ = run_fla(**common) - o_cula, _ = run_cula(**common) - torch.cuda.synchronize() - - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) - - # Performance - ms_fla = time_kernel(lambda: run_fla(**common)) - ms_cula = time_kernel(lambda: run_cula(**common)) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - n_seqs = len(seq_lens) - min_l, max_l = min(seq_lens), max(seq_lens) - avg_l = T // n_seqs - tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" - - results.append( - { - "tag": tag, - "dist": dist, - "T_total": T, - "n_seqs": n_seqs, - "rmse": rmse, - "rel_max": rel_max, - "mean_diff": mean_diff, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - ) - - del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs - torch.cuda.empty_cache() - - return results - - -# ============================================================ -# Report -# ============================================================ -def print_report(fixed_results, varlen_results): - sep = "=" * 110 - print(f"\n\n{sep}") - print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)") - print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton") - print(f" H={H} D={D} dtype=bf16 safe_gate=True 
has_init_state={HAS_INIT_STATE}") - wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP - ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS - mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") - print(f" Warmup={wu} Iters={ni}{mode_tag}") - print(sep) - - if fixed_results: - print("\n [Fixed-Length]") - print(f" {'─' * 90}") - print( - f" {'B':>3s} {'T':>6s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" - f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" - ) - print(f" {'─' * 90}") - for r in fixed_results: - print( - f" {r['B']:3d} {r['T']:6d} │ " - f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " - f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" - ) - print(f" {'─' * 90}") - - if varlen_results: - print("\n [Varlen]") - print(f" {'─' * 105}") - print( - f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" - f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" - ) - print(f" {'─' * 105}") - for r in varlen_results: - print( - f" {r['tag']:>45s} │ " - f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " - f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" - ) - print(f" {'─' * 105}") - - print(f"\n{sep}\n") - - -# ============================================================ -# Main -# ============================================================ -def main(): - parser = argparse.ArgumentParser(description="bench_kda_fused_fwd: cuLA fully-fused KDA forward vs FLA Triton") - parser.add_argument( - "--mode", - type=str, - default="both", - choices=["fixed", "varlen", "both"], - help="Which benchmark mode to run (default: both)", - ) - parser.add_argument( - "--ncu", - action="store_true", - help="NCU profiling mode: warmup=1, iters=1", - ) - parser.add_argument( - "--sanitizer", - action="store_true", - help="Sanitizer mode: warmup=1, iters=1", - ) - parser.add_argument( - "--init_state", - action="store_true", - help="Use non-zero initial state (default: False)", + for test in [ + (4, 4, 128, 0.1, [0, 15], torch.bfloat16, True, True), + (4, 4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True, True), + (4, 4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True, True), + (4, 4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True, True), + (4, 4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True, True), + (2, 4, 128, 0, [0, 63, 130], torch.bfloat16, True, True), + (1, 2, 128, 0, [0, 1], torch.bfloat16, True, True), + (1, 2, 128, 0, [0, 63, 64, 65], torch.bfloat16, True, True), + (2, 4, 128, 0, [0, 17, 64, 65, 130], torch.bfloat16, True, False), + # ======Varlen test with simulated trace======= + ( + 32, + 32, + 128, + 0, + [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], + torch.bfloat16, + True, + True, + ), + ( + 32, + 32, + 128, + 0, + [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], + torch.bfloat16, + True, + True, + ), + ( + 32, + 32, + 128, + 0, + [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], + torch.bfloat16, + True, + True, + ), + ( + 32, + 32, + 128, + 0, + [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], + torch.bfloat16, + True, + True, + ), + ] + ], +) +def test_safe_gate_chunk_varlen( + H: int, + HV: int, + D: int, + mask_p: float, + cu_seqlens: list[int], + dtype: torch.dtype, + safe_gate: bool, + 
use_initial_state: bool, + beta_dtype: torch.dtype, +): + cula_kda_fused_fwd = get_kda_fused_fwd(device) + + torch.manual_seed(42) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + cu_seqlens_cpu = cu_seqlens.cpu() + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = torch.randn((1, T, H, D), dtype=dtype) + k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) + v = torch.randn((1, T, HV, D), dtype=dtype) + g = F.logsigmoid(torch.randn(1, T, HV, D, dtype=torch.float)) + mask = torch.rand_like(g) > mask_p + g = g * mask + (~mask) * (-1000) + if safe_gate: + g = g.clamp(-5, 0) + + beta = torch.randn(1, T, HV, dtype=torch.float32).sigmoid().to(beta_dtype) + h0 = torch.randn((N, HV, D, D), dtype=torch.float32) + # NOTE: for inference scenarios, we only use transposed state layout for better decoding performance + h0_vk = h0.transpose(-1, -2).contiguous() + + q, k, v, g, beta, h0, h0_vk = map(lambda x: x.to(device).requires_grad_(False), (q, k, v, g, beta, h0, h0_vk)) + initial_state = h0.clone() if use_initial_state else None + initial_state_vk = h0_vk.clone() if use_initial_state else None + heads_per_group = HV // H + q_ref = q.repeat_interleave(heads_per_group, dim=2) + k_ref = k.repeat_interleave(heads_per_group, dim=2) + + tri, tri_ht = cula_kda_fused_fwd( + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + initial_state=initial_state_vk, + output_final_state=True, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=safe_gate, + lower_bound=-5.0 if safe_gate else None, ) - args = parser.parse_args() - global NCU_MODE, SANITIZER_MODE, HAS_INIT_STATE - if args.ncu: - NCU_MODE = True - print("[NCU mode] warmup=1, iters=1") - if args.sanitizer: - SANITIZER_MODE = True - print("[Sanitizer mode] warmup=1, iters=1") - if args.init_state: - HAS_INIT_STATE = True - print("[init_state] using non-zero initial state") - - print( - f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG} → using {cula_kda_fused_fwd.__module__}.{cula_kda_fused_fwd.__name__}" + ref_fla, ref_ht_fla = fla_chunk_kda( + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=k_ref.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + initial_state=initial_state, + output_final_state=True, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=safe_gate, + lower_bound=-5.0 if safe_gate else None, ) - fixed_configs = [ - # (B, T) - (1, 512), - (1, 1024), - (1, 4096), - (1, 8192), - (1, 16384), - (2, 512), - (2, 1024), - (2, 4096), - (2, 8192), - (2, 16384), - ] - - varlen_configs = build_varlen_configs( - num_seqs_list=(10, 20), - total_lens=(4096, 8192, 16384), - dists=("uniform", "random", "skewed"), + ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( + q=F.normalize(q_ref.clone(), p=2, dim=-1), + k=k_ref.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + initial_state=initial_state_vk, + output_final_state=True, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=safe_gate, + lower_bound=-5.0 if safe_gate else None, + transpose_state_layout=True, ) - fixed_res, varlen_res = [], [] - - if args.mode in ("fixed", "both"): - fixed_res = bench_fixed(fixed_configs) - - if args.mode in ("varlen", "both"): - varlen_res = bench_varlen(varlen_configs) - - print_report(fixed_res, varlen_res) - - return fixed_res, varlen_res - - -if __name__ == "__main__": - main() + ref = [] + ref_ht = [] + for i in range(N): + ref_i, ref_ht_i = 
naive_recurrent_kda(
+            q=F.normalize(q_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1),
+            k=k_ref[:, cu_seqlens[i] : cu_seqlens[i + 1]],
+            v=v[:, cu_seqlens[i] : cu_seqlens[i + 1]],
+            beta=beta[:, cu_seqlens[i] : cu_seqlens[i + 1]],
+            g=g[:, cu_seqlens[i] : cu_seqlens[i + 1]],
+            initial_state=h0[i] if use_initial_state else None,
+            output_final_state=True,
+        )
+        ref.append(ref_i)
+        ref_ht.append(ref_ht_i)
+    ref = torch.cat(ref, 1)
+    ref_ht = torch.cat(ref_ht, 0)
+
+    assert_close("o", ref, tri, 0.005)
+    assert_close("o", ref_fla, tri, 0.005)
+    assert_close("o", ref_fla_trans, tri, 0.005)
+    assert_close("ht", ref_ht, tri_ht.transpose(-1, -2), 0.005)
+    assert_close("ht", ref_ht_fla, tri_ht.transpose(-1, -2), 0.005)
+    assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005)

From c516f28870d2221ec4b5439a6de726b73c764b5b Mon Sep 17 00:00:00 2001
From: sunnyxyli
Date: Wed, 6 May 2026 16:21:48 +0800
Subject: [PATCH 07/17] sm90

---
 cula/kda/hopper_fused_fwd.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py
index efcfbea..d8b7c9d 100644
--- a/cula/kda/hopper_fused_fwd.py
+++ b/cula/kda/hopper_fused_fwd.py
@@ -55,6 +55,8 @@ def forward(
             "q, k, v, g must share batch and sequence dimensions."
         )
 
+        # num_qk_heads: number of attention heads shared by Q and K, denoted H
+        # num_v_heads: number of attention heads shared by V and G (the gate), denoted HV
         batch_size, seq_len, num_qk_heads, head_dim = q.shape
         num_v_heads = v.shape[-2]
         # Order matters: enforce positivity *before* the modulo so we never % 0.
@@ -206,12 +208,6 @@ def cula_kda_prefill(
             Outputs of shape `[B, T, HV, D]`.
         final_state (torch.Tensor):
             Final state of shape `[N, HV, D, D]` if `output_final_state=True` else `None`.
-
-    GVA constraint:
-      - q.shape == k.shape == [B, T, H, D]
-      - v.shape == g.shape == [B, T, HV, D], beta.shape == [B, T, HV]
-      - HV must be a positive multiple of H. heads_per_group = HV // H.
-      - When HV == H this degenerates to the regular MHA case.
     """
     assert_hopper()
     assert safe_gate, "Only support safe_gate=True."

From c9e4e97c1876ff9b6a89765cc6c975da5b491b49 Mon Sep 17 00:00:00 2001
From: sunnyxyli
Date: Wed, 6 May 2026 16:28:53 +0800
Subject: [PATCH 08/17] sm90

---
 cula/kda/hopper_fused_fwd.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py
index d8b7c9d..db2b9f3 100644
--- a/cula/kda/hopper_fused_fwd.py
+++ b/cula/kda/hopper_fused_fwd.py
@@ -55,8 +55,6 @@ def forward(
             "q, k, v, g must share batch and sequence dimensions."
         )
 
-        # num_qk_heads: number of attention heads shared by Q and K, denoted H
-        # num_v_heads: number of attention heads shared by V and G (the gate), denoted HV
         batch_size, seq_len, num_qk_heads, head_dim = q.shape
         num_v_heads = v.shape[-2]
         # Order matters: enforce positivity *before* the modulo so we never % 0.
@@ -115,11 +113,7 @@ def forward(
         workspace_buffer = _get_cache_buf("hopper_kda_fwd_workspace", workspace_size, q.device)
 
         # call the C++ kernel
-        # Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_,
-        #            alpha_, beta_, cu_seqlens, workspace, scale,
-        #            safe_gate, output_final_state)
-        # Passing output_final_state lets the C++ side skip allocating + writing back
-        # the [N, HV, D, D] fp32 state tensor when the caller does not need it.
+ # Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens, workspace, scale, safe_gate) o, final_state = cula_cuda.kda_fwd_prefill( None, # output_ (auto-allocate) None, # output_state_ (auto-allocate iff output_final_state=True) From 987df50b7bc57ae5bff2b8dabdfaccc6d716b54a Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Wed, 6 May 2026 16:41:10 +0800 Subject: [PATCH 09/17] sm90 --- cula/kda/hopper_fused_fwd.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py index db2b9f3..8b36211 100644 --- a/cula/kda/hopper_fused_fwd.py +++ b/cula/kda/hopper_fused_fwd.py @@ -57,7 +57,6 @@ def forward( batch_size, seq_len, num_qk_heads, head_dim = q.shape num_v_heads = v.shape[-2] - # Order matters: enforce positivity *before* the modulo so we never % 0. assert num_qk_heads > 0, f"num_qk_heads must be positive, got {num_qk_heads}." assert num_v_heads > 0, f"num_v_heads must be positive, got {num_v_heads}." assert num_v_heads % num_qk_heads == 0, ( @@ -99,7 +98,7 @@ def forward( q, q_rstd = l2norm_fwd(q) k, k_rstd = l2norm_fwd(k) - # reshape to packed [T, H, D] / [T, HV, D] for the C++ kernel + # reshape q/k to packed [T, H, K] and v/g to [T, HV, K], beta to [T, HV] for the C++ kernel packed_seq = batch_size * seq_len q = q.reshape(packed_seq, num_qk_heads, head_dim).contiguous() k = k.reshape(packed_seq, num_qk_heads, head_dim).contiguous() @@ -116,7 +115,7 @@ def forward( # Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens, workspace, scale, safe_gate) o, final_state = cula_cuda.kda_fwd_prefill( None, # output_ (auto-allocate) - None, # output_state_ (auto-allocate iff output_final_state=True) + None, # output_state_ (auto-allocate) q, k, v, @@ -127,7 +126,6 @@ def forward( workspace_buffer, scale, safe_gate, - output_final_state, ) # reshape back @@ -165,23 +163,23 @@ def cula_kda_prefill( Args: q (torch.Tensor): - queries of shape `[B, T, H, D]`. + queries of shape `[B, T, H, K]`. k (torch.Tensor): - keys of shape `[B, T, H, D]`. + keys of shape `[B, T, H, K]`. v (torch.Tensor): - values of shape `[B, T, HV, D]`. + values of shape `[B, T, HV, K]`. g (torch.Tensor): - (forget) gating tensor (in log space!) of shape `[B, T, HV, D]`. + (forget) gating tensor (in log space!) of shape `[B, T, HV, K]`. beta (torch.Tensor): betas of shape `[B, T, HV]`. scale (Optional[float]): Scale factor for the KDA attention scores. If not provided, it will default to `1 / sqrt(D)`. Default: `None`. initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, HV, D, D]` for `N` input sequences. + Initial state of shape `[N, HV, K, K]` for `N` input sequences. Default: `None`. output_final_state (Optional[bool]): - Whether to output the final state of shape `[N, HV, D, D]`. Default: `False`. + Whether to output the final state of shape `[N, HV, K, K]`. Default: `False`. use_qk_l2norm_in_kernel (bool): Whether to apply L2norm to the q,k tensor internally. Default: `False`. use_gate_in_kernel (bool): @@ -199,9 +197,9 @@ def cula_kda_prefill( Returns: o (torch.Tensor): - Outputs of shape `[B, T, HV, D]`. + Outputs of shape `[B, T, HV, K]`. final_state (torch.Tensor): - Final state of shape `[N, HV, D, D]` if `output_final_state=True` else `None`. + Final state of shape `[N, HV, K, K]` if `output_final_state=True` else `None`. """ assert_hopper() assert safe_gate, "Only support safe_gate=True." 
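For reference, the packed layout that forward() hands to the C++ kernel can be
reproduced standalone. A minimal sketch, assuming the [B, T, H, K] / [B, T, HV, K]
GVA shapes documented in cula_kda_prefill; pack_for_kernel is a hypothetical helper
for illustration, not part of cula:

    import torch

    def pack_for_kernel(q, k, v, g, beta):
        # q, k: [B, T, H, K]; v, g: [B, T, HV, K]; beta: [B, T, HV]
        B, T, H, K = q.shape
        HV = v.shape[-2]
        assert H > 0 and HV > 0 and HV % H == 0, "HV must be a positive multiple of H"
        packed = B * T
        q = q.reshape(packed, H, K).contiguous()
        k = k.reshape(packed, H, K).contiguous()
        v = v.reshape(packed, HV, K).contiguous()
        g = g.reshape(packed, HV, K).contiguous()
        beta = beta.reshape(packed, HV).contiguous()
        # cu_seqlens for B equal-length sequences of T tokens: [0, T, 2T, ..., B*T]
        cu_seqlens = torch.arange(0, (B + 1) * T, T, dtype=torch.int32, device=q.device)
        return q, k, v, g, beta, cu_seqlens
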
From ca8f431fc198e091c877820cbfbb5c99df36c053 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Wed, 6 May 2026 16:49:27 +0800 Subject: [PATCH 10/17] sm90 --- csrc/api/kda_sm90.cu | 7 +++---- csrc/api/pybind.cu | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index cd0ffed..89df316 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -34,8 +34,7 @@ kda_fwd_prefill( torch::Tensor const& cu_seqlens, torch::Tensor workspace_buffer, float scale, - bool safe_gate, - bool output_final_state) { + bool safe_gate) { // Q, K: [packed_seq, num_qk_heads, D] // V/O/g: [packed_seq, num_v_heads, D] (GVA: num_v_heads is a positive integer multiple of num_qk_heads) auto packed_seq = q.size(0); @@ -154,7 +153,7 @@ kda_fwd_prefill( kda::sm90::launch_kda_fwd_prefill_kernel( stream, reinterpret_cast(output.data_ptr()), - output_state_ptr, + output_state.data_ptr(), reinterpret_cast(q.data_ptr()), reinterpret_cast(k.data_ptr()), reinterpret_cast(v.data_ptr()), @@ -176,7 +175,7 @@ kda_fwd_prefill( kda::sm90::launch_kda_fwd_prefill_kernel( stream, reinterpret_cast(output.data_ptr()), - output_state_ptr, + output_state.data_ptr(), reinterpret_cast(q.data_ptr()), reinterpret_cast(k.data_ptr()), reinterpret_cast(v.data_ptr()), diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu index a38ac6d..ba2deb6 100644 --- a/csrc/api/pybind.cu +++ b/csrc/api/pybind.cu @@ -64,8 +64,7 @@ kda_fwd_prefill( torch::Tensor const& cu_seqlens, torch::Tensor workspace_buffer, float scale, - bool safe_gate, - bool output_final_state); + bool safe_gate); #endif PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { From 1b27a7ff818a9fa09576957b6c38655ab3c9190b Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Wed, 6 May 2026 16:58:49 +0800 Subject: [PATCH 11/17] sm90 --- csrc/api/kda_sm90.cu | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index 89df316..07b6727 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -57,30 +57,18 @@ kda_fwd_prefill( num_qk_heads); TORCH_CHECK(head_size == v.size(2), "KDA requires Q and V head dim to match, got ", head_size, " vs ", v.size(2)); - // Allocate output if not provided. Output is sized by V/O heads. + // Allocate output if not provided torch::Tensor output = output_.has_value() ? output_.value() : torch::empty( {packed_seq, num_v_heads, head_size}, torch::TensorOptions().dtype(q.dtype()).device(q.device())); - // Allocate output state only when the caller actually needs it. When - // output_final_state=False and the caller did not pass an out tensor, we return a - // 0-element placeholder tensor (no GMEM allocation) and pass nullptr to the - // kernel; the kernel will then skip the final-state write-back entirely, saving - // a [N, HV, D, D] fp32 allocation + GMEM store. - torch::Tensor output_state; - bool need_output_state_buffer = output_final_state || output_state_.has_value(); - if (output_state_.has_value()) { - output_state = output_state_.value(); - } else if (need_output_state_buffer) { - output_state = torch::zeros( - {num_seqs, num_v_heads, head_size, head_size}, - torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); - } else { - // 0-element placeholder so the returned tuple never carries an undefined Tensor. 
- output_state = torch::empty({0}, torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); - } - + // Allocate output state if not provided + torch::Tensor output_state = output_state_.has_value() + ? output_state_.value() + : torch::zeros( + {num_seqs, num_v_heads, head_size, head_size}, + torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); // Validate dtypes TORCH_CHECK(q.dtype() == torch::kBFloat16, "q must be bfloat16"); TORCH_CHECK(k.dtype() == torch::kBFloat16, "k must be bfloat16"); From 95086a49df5b6f3307fcaaa2a219faf9abb43a60 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 7 May 2026 16:32:52 +0800 Subject: [PATCH 12/17] add sm90 benchmarks --- benchmarks/bench_kda_fused_fwd.py | 140 +++++++++++++++++++++++------- benchmarks/utils.py | 68 +++++++++++++-- 2 files changed, 168 insertions(+), 40 deletions(-) diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index a87eed3..d297157 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -25,12 +25,17 @@ - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton - Performance: kernel execution time (ms) with CUDA events -Modes: - - Fixed-length: various (B, T) configs - - Varlen: sequences with 2-3x length variation +Modes (each config carries its own (H, HV); GVA is enabled when HV > H): + - Fixed-length: various (B, T, H, HV) configs. + - Varlen: sequences with 2-3x length variation, per-config (H, HV). + +Under GVA (HV > H), q/k/g/beta are expanded from H to HV heads via +`repeat(..., "... h d -> ... (h g) d")`. This keeps FLA's `chunk_kda` +(which does not natively support GVA) and cuLA's SM100 fully-fused forward +(which requires q/k/v to share the head dim) on the same input layout. Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|both|all] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu @@ -67,6 +72,8 @@ # ============================================================ # Constants # ============================================================ +# Default number of Q/K heads. Each benchmark config may additionally specify +# HV (number of V heads) to enable GVA (HV > H must be a positive multiple of H). H, D = 64, 128 WARMUP = 25 N_ITERS = 100 @@ -151,15 +158,47 @@ def run_cula(q, k, v, g, beta, scale, A_log, dt_bias, init_state, cu_seqlens, lo # ============================================================ -# Fixed-length benchmark +# Config normalization helpers +# ============================================================ +def _normalize_fixed_config(cfg): + """Accept either (B, T) or (B, T, H_qk, HV) and return the 4-tuple form. + + For the 2-tuple legacy form, defaults to H_qk=HV=H (no GVA). + """ + if len(cfg) == 2: + B, T = cfg + return B, T, H, H + if len(cfg) == 4: + return cfg + raise ValueError(f"Fixed config must be (B, T) or (B, T, H, HV), got {cfg!r}") + + +def _normalize_varlen_config(cfg): + """Accept (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H_qk, HV). + + For the 3-tuple legacy form, defaults to H_qk=HV=H (no GVA). 
+ """ + if len(cfg) == 3: + seq_lens, total_len, dist = cfg + return seq_lens, total_len, dist, H, H + if len(cfg) == 5: + return cfg + raise ValueError( + f"Varlen config must be (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H, HV), got {cfg!r}" + ) + + +# ============================================================ +# Fixed-length benchmark (GVA-aware via per-config HV) # ============================================================ def bench_fixed(configs): print("\n" + "=" * 100) - print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") print("=" * 100) results = [] - for B, T in configs: + for cfg in configs: + B, T, H_qk, HV = _normalize_fixed_config(cfg) set_seed(SEED) device = torch.device("cuda") torch.cuda.empty_cache() @@ -167,7 +206,12 @@ def bench_fixed(configs): seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs = prepare_safe_gate_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) + inputs = prepare_safe_gate_inputs( + B, T, H_qk, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + num_v_heads=HV, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] @@ -202,6 +246,8 @@ def bench_fixed(configs): { "B": B, "T": T, + "H": H_qk, + "HV": HV, "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, @@ -218,15 +264,16 @@ def bench_fixed(configs): # ============================================================ -# Varlen benchmark +# Varlen benchmark (GVA-aware via per-config HV) # ============================================================ def bench_varlen(configs): print("\n" + "=" * 100) - print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") print("=" * 100) results = [] - for seq_lens, total_len, dist in configs: + for cfg in configs: + seq_lens, total_len, dist, H_qk, HV = _normalize_varlen_config(cfg) set_seed(SEED) device = torch.device("cuda") torch.cuda.empty_cache() @@ -234,7 +281,12 @@ def bench_varlen(configs): T = total_len cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs = prepare_safe_gate_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE) + inputs = prepare_safe_gate_inputs( + 1, T, H_qk, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + num_v_heads=HV, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] @@ -276,6 +328,8 @@ def bench_varlen(configs): "dist": dist, "T_total": T, "n_seqs": n_seqs, + "H": H_qk, + "HV": HV, "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, @@ -295,11 +349,12 @@ def bench_varlen(configs): # Report # ============================================================ def print_report(fixed_results, varlen_results): - sep = "=" * 110 + sep = "=" * 120 print(f"\n\n{sep}") print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)") print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton") - print(f" H={H} D={D} dtype=bf16 safe_gate=True 
has_init_state={HAS_INIT_STATE}") + print(f" D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}") + print(f" GVA rows are those with HV > H (H, HV shown per row).") wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") @@ -308,35 +363,39 @@ def print_report(fixed_results, varlen_results): if fixed_results: print("\n [Fixed-Length]") - print(f" {'─' * 90}") + print(f" {'─' * 110}") print( - f" {'B':>3s} {'T':>6s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" - f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) - print(f" {'─' * 90}") + print(f" {'─' * 110}") for r in fixed_results: + gva_tag = "yes" if r["HV"] > r["H"] else "no" print( - f" {r['B']:3d} {r['T']:6d} │ " + f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 90}") + print(f" {'─' * 110}") if varlen_results: print("\n [Varlen]") - print(f" {'─' * 105}") + print(f" {'─' * 120}") print( - f" {'Config':>45s} │ {'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s}" - f" │ {'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'GVA':>4s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) - print(f" {'─' * 105}") + print(f" {'─' * 120}") for r in varlen_results: + gva_tag = "yes" if r["HV"] > r["H"] else "no" print( - f" {r['tag']:>45s} │ " + f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" ) - print(f" {'─' * 105}") + print(f" {'─' * 120}") print(f"\n{sep}\n") @@ -349,9 +408,9 @@ def main(): parser.add_argument( "--mode", type=str, - default="both", - choices=["fixed", "varlen", "both"], - help="Which benchmark mode to run (default: both)", + default="all", + choices=["fixed", "varlen", "both", "all"], + help="Which benchmark mode to run (default: all). 'all'/'both' = fixed + varlen.", ) parser.add_argument( "--ncu", @@ -385,8 +444,10 @@ def main(): f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG} → using {cula_kda_fused_fwd.__module__}.{cula_kda_fused_fwd.__name__}" ) + # Fixed-length configs: (B, T) means H_qk=HV=H (no GVA); (B, T, H_qk, HV) + # activates GVA when HV > H. The two forms can be freely mixed. fixed_configs = [ - # (B, T) + # Non-GVA (H_qk == HV == H): (1, 512), (1, 1024), (1, 4096), @@ -397,20 +458,33 @@ def main(): (2, 4096), (2, 8192), (2, 16384), + # GVA (HV > H, same D=128): + (1, 1024, 16, 64), + (1, 4096, 16, 64), + (1, 8192, 16, 64), + (1, 4096, 32, 64), + (1, 8192, 32, 64), + (2, 4096, 16, 64), + (2, 8192, 16, 64), ] - varlen_configs = build_varlen_configs( + # Varlen configs: 3-tuples (seq_lens, total_len, dist) default to no GVA; + # extend to 5-tuples (..., H_qk, HV) to activate GVA on varlen workloads. + varlen_configs_base = build_varlen_configs( num_seqs_list=(10, 20), total_lens=(4096, 8192, 16384), dists=("uniform", "random", "skewed"), ) + # A small GVA subset reuses the non-GVA varlen shapes with (H_qk=16, HV=64). 
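+    # For example (an illustrative pairing, not output of this script): a base entry
+    # (seq_lens, 8192, "random") becomes (seq_lens, 8192, "random", 16, 64),
+    # i.e. 16 QK heads grouped into 64 V heads (4x GVA ratio).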
+ gva_varlen = [(seq_lens, T, dist, 16, 64) for (seq_lens, T, dist) in varlen_configs_base if T <= 8192] + varlen_configs = list(varlen_configs_base) + gva_varlen fixed_res, varlen_res = [], [] - if args.mode in ("fixed", "both"): + if args.mode in ("fixed", "both", "all"): fixed_res = bench_fixed(fixed_configs) - if args.mode in ("varlen", "both"): + if args.mode in ("varlen", "both", "all"): varlen_res = bench_varlen(varlen_configs) print_report(fixed_res, varlen_res) diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 29bab04..2e75a6c 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -256,25 +256,77 @@ def build_varlen_configs( def prepare_safe_gate_inputs( - batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, has_init_state=False + batch_size, + T, + H, + D, + device, + cu_seqlens=None, + chunk_size=CHUNK_SIZE, + seed=SEED, + has_init_state=False, + num_v_heads=None, ): """Prepare inputs for safe_gate benchmarks (use_gate_in_kernel=True, safe_gate=True). + Args: + batch_size: Batch size `B`. + T: Per-sequence length. + H: Number of Q/K heads (the "group" count, must be > 0). + D: Head dimension. + device: torch device. + cu_seqlens: Optional cumulative sequence lengths for varlen. + chunk_size: Chunk size for `prepare_chunk_indices`. + seed: RNG seed. + has_init_state: Whether to allocate a non-zero initial state. + num_v_heads: Optional number of V heads (HV). Defaults to `H` (no GVA). + When `HV > H`, GVA (Grouped Value Attention) is enabled: + - v/g/beta are allocated with HV heads. + - q/k/g/beta are expanded from H to HV heads via + `repeat_interleave(..., dim=head)`, equivalent to the einops + pattern `repeat(x, "... h d -> ... (h g) d")` used by + `fla.layers.kda`. + - `A_log` / `dt_bias` are sized to HV because FLA's + `kda_gate_chunk_cumsum` indexes them per head of `g`. + This expanded layout works on both cuLA backends (SM100 requires + q.shape[-2] == v.shape[-2]; SM90 accepts both native and expanded + GVA) and on FLA's `chunk_kda` (which does not natively support GVA). + All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility. + The returned dict always contains `H` and `HV` keys so callers can report + the effective head counts uniformly. """ + HV = H if num_v_heads is None else num_v_heads + assert H > 0, f"H must be positive, got {H}." + assert HV > 0, f"HV must be positive, got {HV}." + assert HV % H == 0, f"HV ({HV}) must be a positive multiple of H ({H})." + dtype = torch.bfloat16 scale = D ** (-0.5) set_seed(seed) + # Allocate native GVA shapes: + # q, k: (B, T, H, D) + # v, g: (B, T, HV, D) + # beta: (B, T, HV) + # When HV == H this collapses to the standard non-GVA layout. q = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) k = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - v = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - g = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - beta = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid().requires_grad_(False) + v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) + g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) + beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) + + # GVA expansion: bring q/k up to HV heads so all tensors share head dim. 
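+    # Equivalent einops form of the expansion below (einops assumed available;
+    # it is not imported in this module):
+    #   q = einops.repeat(q, "b t h d -> b t (h g) d", g=HV // H)
+    # repeat_interleave keeps the same head ordering: each of the H source heads
+    # is duplicated g times consecutively.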
+ group = HV // H + if group > 1: + q = q.repeat_interleave(group, dim=2).contiguous() + k = k.repeat_interleave(group, dim=2).contiguous() - A_log = torch.randn(H, dtype=torch.float, device=device).requires_grad_(False) - dt_bias = torch.randn(H * D, dtype=torch.float, device=device).requires_grad_(False) + # A_log / dt_bias must match the head count of `g` (HV), otherwise + # kda_gate_chunk_cumsum would index out of bounds for i_h >= H. + A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False) # flatten to batch_size=1 for cu_seqlens compatibility if batch_size != 1: @@ -285,7 +337,7 @@ def prepare_safe_gate_inputs( init_state = None if has_init_state: num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size - init_state = torch.randn(num_seqs, H, D, D, dtype=torch.float, device=device).requires_grad_(False) + init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False) return dict( q=q, @@ -300,6 +352,8 @@ def prepare_safe_gate_inputs( chunk_indices=chunk_indices, init_state=init_state, lower_bound=-5.0, + H=H, + HV=HV, ) From 2b0c922636759fed0b7d32ef4669309b4c6de93b Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 7 May 2026 18:41:53 +0800 Subject: [PATCH 13/17] benchmarks --- benchmarks/bench_kda_fused_fwd.py | 133 ++++++++++++++++++------------ benchmarks/utils.py | 59 +++++++++++++ 2 files changed, 140 insertions(+), 52 deletions(-) diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index d297157..52570d2 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -25,23 +25,23 @@ - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton - Performance: kernel execution time (ms) with CUDA events -Modes (each config carries its own (H, HV); GVA is enabled when HV > H): - - Fixed-length: various (B, T, H, HV) configs. - - Varlen: sequences with 2-3x length variation, per-config (H, HV). +Modes: + - Fixed-length: various (B, T) configs; a small GVA supplement uses (B, T, H, HV). + - Varlen: sequences with realistic length variation; a few GVA configs included. -Under GVA (HV > H), q/k/g/beta are expanded from H to HV heads via -`repeat(..., "... h d -> ... (h g) d")`. This keeps FLA's `chunk_kda` -(which does not natively support GVA) and cuLA's SM100 fully-fused forward -(which requires q/k/v to share the head dim) on the same input layout. +Each config may carry an explicit (H_qk, HV) pair to enable GVA (HV > H). +Under GVA, q/k/g/beta are expanded from H to HV heads so that both FLA and +cuLA operate on the same input layout. 
Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both|all] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu """ import argparse +import datetime import os import pathlib import sys @@ -56,6 +56,7 @@ SEED, build_varlen_configs, exclusive_cumsum, + prepare_gva_inputs, prepare_safe_gate_inputs, set_seed, ) @@ -193,7 +194,7 @@ def _normalize_varlen_config(cfg): # ============================================================ def bench_fixed(configs): print("\n" + "=" * 100) - print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") + print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") print("=" * 100) results = [] @@ -206,31 +207,32 @@ def bench_fixed(configs): seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs = prepare_safe_gate_inputs( - B, T, H_qk, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - num_v_heads=HV, - ) + if HV > H_qk: + inputs = prepare_gva_inputs( + B, T, H_qk, HV, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + ) + else: + inputs = prepare_safe_gate_inputs( + B, T, H_qk, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, + q=q, k=k, v=v, g=g, beta=beta, + scale=scale, A_log=A_log, dt_bias=dt_bias, + init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound, ) - # Accuracy + # Accuracy: run once to warm up JIT, then measure on the second call + run_fla(**common) + run_cula(**common) + torch.cuda.synchronize() o_fla, _ = run_fla(**common) o_cula, _ = run_cula(**common) torch.cuda.synchronize() @@ -258,6 +260,7 @@ def bench_fixed(configs): ) del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + del common, scale, init_state, lower_bound, cu_seqlens torch.cuda.empty_cache() return results @@ -268,7 +271,7 @@ def bench_fixed(configs): # ============================================================ def bench_varlen(configs): print("\n" + "=" * 100) - print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") + print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") print("=" * 100) results = [] @@ -281,12 +284,18 @@ def bench_varlen(configs): T = total_len cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - inputs = prepare_safe_gate_inputs( - 1, T, H_qk, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - num_v_heads=HV, - ) + if HV > H_qk: + inputs = prepare_gva_inputs( + 1, T, H_qk, HV, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + ) + else: + inputs = prepare_safe_gate_inputs( + 1, T, H_qk, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], 
inputs["lower_bound"] @@ -305,7 +314,10 @@ def bench_varlen(configs): lower_bound=lower_bound, ) - # Accuracy + # Accuracy: run once to warm up JIT, then measure on the second call + run_fla(**common) + run_cula(**common) + torch.cuda.synchronize() o_fla, _ = run_fla(**common) o_cula, _ = run_cula(**common) torch.cuda.synchronize() @@ -340,6 +352,7 @@ def bench_varlen(configs): ) del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + del common, scale, init_state, lower_bound, cu_seqlens torch.cuda.empty_cache() return results @@ -350,6 +363,7 @@ def bench_varlen(configs): # ============================================================ def print_report(fixed_results, varlen_results): sep = "=" * 120 + ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f"\n\n{sep}") print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)") print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton") @@ -359,6 +373,7 @@ def print_report(fixed_results, varlen_results): ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") print(f" Warmup={wu} Iters={ni}{mode_tag}") + print(f" {ts}") print(sep) if fixed_results: @@ -370,8 +385,13 @@ def print_report(fixed_results, varlen_results): f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) print(f" {'─' * 110}") + prev_gva = False for r in fixed_results: - gva_tag = "yes" if r["HV"] > r["H"] else "no" + is_gva = r["HV"] > r["H"] + if is_gva and not prev_gva: + print(f" {'·' * 110} ← GVA (HV > H)") + prev_gva = is_gva + gva_tag = "yes" if is_gva else "no" print( f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " @@ -388,8 +408,13 @@ def print_report(fixed_results, varlen_results): f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) print(f" {'─' * 120}") + prev_gva = False for r in varlen_results: - gva_tag = "yes" if r["HV"] > r["H"] else "no" + is_gva = r["HV"] > r["H"] + if is_gva and not prev_gva: + print(f" {'·' * 120} ← GVA (HV > H)") + prev_gva = is_gva + gva_tag = "yes" if is_gva else "no" print( f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " @@ -408,9 +433,9 @@ def main(): parser.add_argument( "--mode", type=str, - default="all", - choices=["fixed", "varlen", "both", "all"], - help="Which benchmark mode to run (default: all). 'all'/'both' = fixed + varlen.", + default="both", + choices=["fixed", "varlen", "both"], + help="Which benchmark mode to run (default: both).", ) parser.add_argument( "--ncu", @@ -458,14 +483,13 @@ def main(): (2, 4096), (2, 8192), (2, 16384), - # GVA (HV > H, same D=128): - (1, 1024, 16, 64), - (1, 4096, 16, 64), - (1, 8192, 16, 64), - (1, 4096, 32, 64), - (1, 8192, 32, 64), - (2, 4096, 16, 64), - (2, 8192, 16, 64), + # GVA (HV > H, same D=128): representative 2x and 4x ratios at key T values. + # Kept small intentionally — this script benchmarks the general KDA kernel; + # GVA configs are a supplementary check, not the primary focus. + (1, 4096, 32, 64), # 2x ratio + (1, 8192, 32, 64), # 2x ratio + (1, 4096, 16, 64), # 4x ratio + (1, 8192, 16, 64), # 4x ratio ] # Varlen configs: 3-tuples (seq_lens, total_len, dist) default to no GVA; @@ -475,16 +499,21 @@ def main(): total_lens=(4096, 8192, 16384), dists=("uniform", "random", "skewed"), ) - # A small GVA subset reuses the non-GVA varlen shapes with (H_qk=16, HV=64). 
- gva_varlen = [(seq_lens, T, dist, 16, 64) for (seq_lens, T, dist) in varlen_configs_base if T <= 8192] + # GVA varlen: one representative config per total-length (random dist, 10 seqs, + # 4x ratio H=16→HV=64). Kept minimal — supplementary GVA coverage only. + gva_varlen = [ + (seq_lens, T, dist, 16, 64) + for (seq_lens, T, dist) in varlen_configs_base + if dist == "random" and len(seq_lens) == 10 + ] varlen_configs = list(varlen_configs_base) + gva_varlen fixed_res, varlen_res = [], [] - if args.mode in ("fixed", "both", "all"): + if args.mode in ("fixed", "both"): fixed_res = bench_fixed(fixed_configs) - if args.mode in ("varlen", "both", "all"): + if args.mode in ("varlen", "both"): varlen_res = bench_varlen(varlen_configs) print_report(fixed_res, varlen_res) diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 2e75a6c..5c5f918 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -357,6 +357,65 @@ def prepare_safe_gate_inputs( ) +def prepare_gva_inputs( + batch_size, + T, + H, + HV, + D, + device, + cu_seqlens=None, + chunk_size=CHUNK_SIZE, + seed=SEED, + has_init_state=False, +): + """Prepare inputs for GVA (Grouped Value Attention) benchmarks. + + GVA uses separate head counts for QK (H) and V (HV), where HV > H and + HV is a positive multiple of H. The grouping factor is G = HV // H. + + Layout produced (after GVA expansion): + q, k, g, beta : HV heads (q/k expanded from H via repeat_interleave) + v : HV heads + A_log, dt_bias: sized to HV + + This is a thin, named wrapper around ``prepare_safe_gate_inputs`` with + ``num_v_heads=HV``. Having an explicit function makes the GVA data path + visible in benchmark scripts and easier to discover. + + Args: + batch_size: Batch size ``B``. + T: Per-sequence token count. + H: Number of QK heads (group count). + HV: Number of V heads; must satisfy ``HV > H`` and ``HV % H == 0``. + D: Head dimension. + device: Target torch device. + cu_seqlens: Optional cumulative sequence-length tensor for varlen mode. + chunk_size: Chunk size passed to ``prepare_chunk_indices``. + seed: RNG seed for reproducibility. + has_init_state: If True, allocate a non-zero recurrent initial state. + + Returns: + Same dict as ``prepare_safe_gate_inputs``. + """ + if HV <= H: + raise ValueError(f"GVA requires HV > H, got H={H}, HV={HV}.") + if HV % H != 0: + raise ValueError(f"HV ({HV}) must be a positive multiple of H ({H}).") + return prepare_safe_gate_inputs( + batch_size=batch_size, + T=T, + H=H, + D=D, + device=device, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + seed=seed, + has_init_state=has_init_state, + num_v_heads=HV, + ) + + def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED): """Prepare preprocessed inputs ready for chunk_kda_fwd_intra. 
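A minimal usage sketch for the prepare_gva_inputs wrapper above, assuming a CUDA
device and that benchmarks/ is importable (the bench script inserts the repo root
into sys.path); shapes follow the GVA layout the docstring describes:

    import torch
    from benchmarks.utils import prepare_gva_inputs

    device = torch.device("cuda")
    cu_seqlens = torch.tensor([0, 4096], dtype=torch.int32, device=device)
    # 16 QK heads grouped into 64 V heads (4x ratio), one 4096-token sequence.
    inputs = prepare_gva_inputs(1, 4096, 16, 64, 128, device, cu_seqlens=cu_seqlens)
    assert inputs["q"].shape == (1, 4096, 64, 128)  # q/k expanded from 16 to 64 heads
    assert inputs["v"].shape == (1, 4096, 64, 128)
    assert inputs["H"] == 16 and inputs["HV"] == 64
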
From b538129906cd5f0f87b16727a3e0218e79c64381 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Thu, 7 May 2026 19:14:55 +0800 Subject: [PATCH 14/17] benchmarks --- benchmarks/bench_kda_fused_fwd.py | 521 +++++++++++++++++++++++++----- benchmarks/utils.py | 160 ++++++--- 2 files changed, 548 insertions(+), 133 deletions(-) diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index 52570d2..135527d 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -25,23 +25,34 @@ - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton - Performance: kernel execution time (ms) with CUDA events -Modes: - - Fixed-length: various (B, T) configs; a small GVA supplement uses (B, T, H, HV). - - Varlen: sequences with realistic length variation; a few GVA configs included. - -Each config may carry an explicit (H_qk, HV) pair to enable GVA (HV > H). -Under GVA, q/k/g/beta are expanded from H to HV heads so that both FLA and -cuLA operate on the same input layout. +Modes (--mode, default: all): + - fixed: Fixed-length sequences, various (B, T, H, HV) configs. + GVA rows (HV > H) are mixed in alongside non-GVA rows. + - varlen: Variable-length sequences with 2-3x length variation. + Non-GVA base configs plus a GVA subset (H=16, HV=64). + - gva: Dedicated GVA benchmark (fixed + varlen) using prepare_gva_inputs. + Covers multiple GVA ratios (2x / 4x / 8x); compares cuLA vs FLA. + - overhead: GVA overhead benchmark — cuLA GVA vs cuLA non-GVA at the same + total head count (HV). Both paths present identical tensor shapes + to the kernel, so a near-zero overhead% proves that GVA adds no + measurable kernel latency regression. + - both: Fixed-length + varlen only (legacy alias, no gva/overhead). + - all: Run all of the above. + +Under GVA (HV > H), q/k are expanded from H to HV heads via +`repeat_interleave(..., dim=2)`, equivalent to the einops pattern +`repeat(x, "... h d -> ... (h g) d")`. This keeps FLA's `chunk_kda` +(which does not natively support GVA) and cuLA's SM100 fully-fused forward +(which requires q/k/v to share the head dim) on the same input layout. 
Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|gva|overhead|both|all] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: - ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu + ncu --set full -o report python bench_kda_fused_fwd.py --mode overhead --ncu """ import argparse -import datetime import os import pathlib import sys @@ -54,6 +65,8 @@ from benchmarks.utils import ( SEED, + build_gva_fixed_configs, + build_gva_varlen_configs, build_varlen_configs, exclusive_cumsum, prepare_gva_inputs, @@ -194,7 +207,7 @@ def _normalize_varlen_config(cfg): # ============================================================ def bench_fixed(configs): print("\n" + "=" * 100) - print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print(f" Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") print("=" * 100) results = [] @@ -207,32 +220,31 @@ def bench_fixed(configs): seq_lens = [T] * B cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - if HV > H_qk: - inputs = prepare_gva_inputs( - B, T, H_qk, HV, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - ) - else: - inputs = prepare_safe_gate_inputs( - B, T, H_qk, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - ) + inputs = prepare_safe_gate_inputs( + B, T, H_qk, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + num_v_heads=HV, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] common = dict( - q=q, k=k, v=v, g=g, beta=beta, - scale=scale, A_log=A_log, dt_bias=dt_bias, - init_state=init_state, cu_seqlens=cu_seqlens, lower_bound=lower_bound, + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, ) - # Accuracy: run once to warm up JIT, then measure on the second call - run_fla(**common) - run_cula(**common) - torch.cuda.synchronize() + # Accuracy o_fla, _ = run_fla(**common) o_cula, _ = run_cula(**common) torch.cuda.synchronize() @@ -260,7 +272,6 @@ def bench_fixed(configs): ) del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs - del common, scale, init_state, lower_bound, cu_seqlens torch.cuda.empty_cache() return results @@ -271,7 +282,7 @@ def bench_fixed(configs): # ============================================================ def bench_varlen(configs): print("\n" + "=" * 100) - print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print(f" Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton (GVA when HV > H)") print("=" * 100) results = [] @@ -284,21 +295,104 @@ def bench_varlen(configs): T = total_len cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - if HV > H_qk: - inputs = prepare_gva_inputs( - 1, T, H_qk, HV, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - ) - else: - inputs = prepare_safe_gate_inputs( - 1, T, H_qk, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - ) + inputs = prepare_safe_gate_inputs( + 1, T, H_qk, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + num_v_heads=HV, + ) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], 
inputs["beta"] + A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] + scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) + + # Accuracy + o_fla, _ = run_fla(**common) + o_cula, _ = run_cula(**common) + torch.cuda.synchronize() + + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = time_kernel(lambda: run_fla(**common)) + ms_cula = time_kernel(lambda: run_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + n_seqs = len(seq_lens) + min_l, max_l = min(seq_lens), max(seq_lens) + avg_l = T // n_seqs + tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" + + results.append( + { + "tag": tag, + "dist": dist, + "T_total": T, + "n_seqs": n_seqs, + "H": H_qk, + "HV": HV, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + ) + + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# GVA-dedicated benchmarks (uses prepare_gva_inputs) +# ============================================================ +def bench_gva_fixed(configs): + """Fixed-length GVA benchmark using :func:`prepare_gva_inputs`. + + All configs must have HV > H. Data is prepared the same way as in the + KimiDeltaAttention layer: q/k/g/beta are first generated with H heads and + then expanded to HV heads via einops repeat before being fed to both cuLA + and FLA. + """ + print("\n" + "=" * 100) + print(f" GVA Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print("=" * 100) + results = [] + + for B, T, H_qk, HV in configs: + assert HV > H_qk, f"GVA requires HV > H, got H={H_qk} HV={HV}" + set_seed(SEED) + device = torch.device("cuda") + torch.cuda.empty_cache() + + seq_lens = [T] * B + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + inputs = prepare_gva_inputs( + B, T, H_qk, HV, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + ) q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + gva_ratio = inputs["gva_ratio"] common = dict( q=q, @@ -314,10 +408,86 @@ def bench_varlen(configs): lower_bound=lower_bound, ) - # Accuracy: run once to warm up JIT, then measure on the second call - run_fla(**common) - run_cula(**common) + # Accuracy + o_fla, _ = run_fla(**common) + o_cula, _ = run_cula(**common) torch.cuda.synchronize() + + rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) + + # Performance + ms_fla = time_kernel(lambda: run_fla(**common)) + ms_cula = time_kernel(lambda: run_cula(**common)) + speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") + + results.append( + { + "B": B, + "T": T, + "H": H_qk, + "HV": HV, + "gva_ratio": gva_ratio, + "rmse": rmse, + "rel_max": rel_max, + "mean_diff": mean_diff, + "ms_fla": ms_fla, + "ms_cula": ms_cula, + "speedup": speedup, + } + ) + + del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs + torch.cuda.empty_cache() + + return results + + +def bench_gva_varlen(configs): + """Varlen GVA benchmark using 
:func:`prepare_gva_inputs`. + + Configs are 5-tuples (seq_lens, total_len, dist, H, HV) as produced by + :func:`~benchmarks.utils.build_gva_varlen_configs`. + All configs must have HV > H. + """ + print("\n" + "=" * 100) + print(f" GVA Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") + print("=" * 100) + results = [] + + for seq_lens, total_len, dist, H_qk, HV in configs: + assert HV > H_qk, f"GVA requires HV > H, got H={H_qk} HV={HV}" + set_seed(SEED) + device = torch.device("cuda") + torch.cuda.empty_cache() + + T = total_len + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + inputs = prepare_gva_inputs( + 1, T, H_qk, HV, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + ) + q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] + A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] + scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] + gva_ratio = inputs["gva_ratio"] + + common = dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + A_log=A_log, + dt_bias=dt_bias, + init_state=init_state, + cu_seqlens=cu_seqlens, + lower_bound=lower_bound, + ) + + # Accuracy o_fla, _ = run_fla(**common) o_cula, _ = run_cula(**common) torch.cuda.synchronize() @@ -342,6 +512,7 @@ def bench_varlen(configs): "n_seqs": n_seqs, "H": H_qk, "HV": HV, + "gva_ratio": gva_ratio, "rmse": rmse, "rel_max": rel_max, "mean_diff": mean_diff, @@ -352,7 +523,99 @@ def bench_varlen(configs): ) del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs - del common, scale, init_state, lower_bound, cu_seqlens + torch.cuda.empty_cache() + + return results + + +# ============================================================ +# GVA overhead benchmark +# (proves GVA adds no kernel cost vs a plain non-GVA run) +# ============================================================ +def bench_gva_overhead(configs): + """Quantify the kernel overhead introduced by GVA vs a plain non-GVA run. + + For every ``(B, T, H_qk, HV)`` config this function runs **cuLA only** + with two different input preparations — both produce tensors of identical + shape ``(1, B*T, HV, D)`` entering the kernel: + + * **baseline** – standard non-GVA: H = HV unique q/k heads, prepared via + :func:`prepare_safe_gate_inputs` with ``num_v_heads=HV``. + * **GVA** – grouped q/k: H < HV heads expanded to HV via + ``repeat_interleave``, prepared via :func:`prepare_gva_inputs`. + + Because the kernel receives identically-shaped tensors in both cases, the + extra work that GVA adds is *only* the ``repeat_interleave`` call done in + Python before the kernel is launched. A near-zero ``overhead%`` column in + the report confirms that the GVA feature introduces no measurable kernel + latency regression. + + Note: FLA is intentionally excluded; the comparison is purely cuLA vs cuLA. 
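+
+    A minimal shape sketch of the two preparations (hedged: ``B``, ``T``,
+    ``H_qk``, ``HV`` and ``dev`` are illustrative stand-ins, and the real
+    calls above additionally pass ``cu_seqlens``)::
+
+        base = prepare_safe_gate_inputs(B, T, HV, D, dev, num_v_heads=HV)
+        gva = prepare_gva_inputs(B, T, H_qk, HV, D, dev)
+        assert base["q"].shape == gva["q"].shape  # both (1, B*T, HV, D)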
+ """ + print("\n" + "=" * 100) + print(f" GVA Overhead Benchmark: cuLA GVA vs cuLA non-GVA (same kernel shape, {_SM_TAG})") + print("=" * 100) + results = [] + + for cfg in configs: + B, T, H_qk, HV = _normalize_fixed_config(cfg) + assert HV > H_qk, f"GVA overhead bench requires HV > H, got H={H_qk} HV={HV}" + gva_ratio = HV // H_qk + set_seed(SEED) + device = torch.device("cuda") + torch.cuda.empty_cache() + + seq_lens = [T] * B + cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) + + # ── baseline: non-GVA with HV heads (H == HV) ──────────────────────── + inp_base = prepare_safe_gate_inputs( + B, T, HV, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + num_v_heads=HV, + ) + common_base = dict( + q=inp_base["q"], k=inp_base["k"], v=inp_base["v"], + g=inp_base["g"], beta=inp_base["beta"], + scale=inp_base["scale"], A_log=inp_base["A_log"], + dt_bias=inp_base["dt_bias"], init_state=inp_base["init_state"], + cu_seqlens=cu_seqlens, lower_bound=inp_base["lower_bound"], + ) + + # ── GVA: H_qk heads expanded to HV ─────────────────────────────────── + inp_gva = prepare_gva_inputs( + B, T, H_qk, HV, D, device, + cu_seqlens=cu_seqlens, + has_init_state=HAS_INIT_STATE, + ) + common_gva = dict( + q=inp_gva["q"], k=inp_gva["k"], v=inp_gva["v"], + g=inp_gva["g"], beta=inp_gva["beta"], + scale=inp_gva["scale"], A_log=inp_gva["A_log"], + dt_bias=inp_gva["dt_bias"], init_state=inp_gva["init_state"], + cu_seqlens=cu_seqlens, lower_bound=inp_gva["lower_bound"], + ) + + # ── performance ─────────────────────────────────────────────────────── + ms_base = time_kernel(lambda: run_cula(**common_base)) + ms_gva = time_kernel(lambda: run_cula(**common_gva)) + overhead_pct = (ms_gva - ms_base) / ms_base * 100.0 if ms_base > 0 else 0.0 + + results.append( + { + "B": B, + "T": T, + "H": H_qk, + "HV": HV, + "gva_ratio": gva_ratio, + "ms_base": ms_base, + "ms_gva": ms_gva, + "overhead_pct": overhead_pct, + } + ) + + del inp_base, inp_gva torch.cuda.empty_cache() return results @@ -361,9 +624,8 @@ def bench_varlen(configs): # ============================================================ # Report # ============================================================ -def print_report(fixed_results, varlen_results): +def print_report(fixed_results, varlen_results, gva_fixed_results=None, gva_varlen_results=None, overhead_results=None): sep = "=" * 120 - ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f"\n\n{sep}") print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)") print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton") @@ -373,7 +635,6 @@ def print_report(fixed_results, varlen_results): ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") print(f" Warmup={wu} Iters={ni}{mode_tag}") - print(f" {ts}") print(sep) if fixed_results: @@ -385,13 +646,8 @@ def print_report(fixed_results, varlen_results): f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" ) print(f" {'─' * 110}") - prev_gva = False for r in fixed_results: - is_gva = r["HV"] > r["H"] - if is_gva and not prev_gva: - print(f" {'·' * 110} ← GVA (HV > H)") - prev_gva = is_gva - gva_tag = "yes" if is_gva else "no" + gva_tag = "yes" if r["HV"] > r["H"] else "no" print( f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " @@ -408,13 +664,8 @@ def print_report(fixed_results, varlen_results): f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} 
{'Speedup':>8s}" ) print(f" {'─' * 120}") - prev_gva = False for r in varlen_results: - is_gva = r["HV"] > r["H"] - if is_gva and not prev_gva: - print(f" {'·' * 120} ← GVA (HV > H)") - prev_gva = is_gva - gva_tag = "yes" if is_gva else "no" + gva_tag = "yes" if r["HV"] > r["H"] else "no" print( f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " @@ -422,6 +673,57 @@ def print_report(fixed_results, varlen_results): ) print(f" {'─' * 120}") + if gva_fixed_results: + print("\n [GVA Fixed-Length] (data prepared via prepare_gva_inputs)") + print(f" {'─' * 116}") + print( + f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'ratio':>5s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + ) + print(f" {'─' * 116}") + for r in gva_fixed_results: + print( + f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {r['gva_ratio']:4d}x │ " + f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" + ) + print(f" {'─' * 116}") + + if gva_varlen_results: + print("\n [GVA Varlen] (data prepared via prepare_gva_inputs)") + print(f" {'─' * 126}") + print( + f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'ratio':>5s} │ " + f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " + f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" + ) + print(f" {'─' * 126}") + for r in gva_varlen_results: + print( + f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {r['gva_ratio']:4d}x │ " + f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " + f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" + ) + print(f" {'─' * 126}") + + if overhead_results: + print("\n [GVA Overhead] cuLA GVA vs cuLA non-GVA — same kernel shape, same HV heads") + print(" (near-zero overhead% proves GVA adds no kernel latency)") + print(f" {'─' * 96}") + print( + f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'ratio':>5s} │ " + f"{'base(ms)':>10s} {'gva(ms)':>10s} {'overhead%':>10s}" + ) + print(f" {'─' * 96}") + for r in overhead_results: + flag = " ✓" if abs(r["overhead_pct"]) < 3.0 else " !" + print( + f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {r['gva_ratio']:4d}x │ " + f"{r['ms_base']:10.4f} {r['ms_gva']:10.4f} {r['overhead_pct']:+9.2f}%{flag}" + ) + print(f" {'─' * 96}") + print(f"\n{sep}\n") @@ -433,9 +735,17 @@ def main(): parser.add_argument( "--mode", type=str, - default="both", - choices=["fixed", "varlen", "both"], - help="Which benchmark mode to run (default: both).", + default="all", + choices=["fixed", "varlen", "gva", "overhead", "both", "all"], + help=( + "Which benchmark mode to run (default: all). " + "'all' = fixed + varlen + gva + overhead. " + "'both' = fixed + varlen only (legacy alias). " + "'gva' runs the dedicated GVA benchmark (fixed + varlen, " + "data prepared via prepare_gva_inputs with multiple GVA ratios). " + "'overhead' compares cuLA-GVA vs cuLA-non-GVA at the same total " + "head count to prove GVA adds no kernel latency regression." 
+ ), ) parser.add_argument( "--ncu", @@ -469,8 +779,10 @@ def main(): f"[Device] {torch.cuda.get_device_name(0)} compute capability {_SM_TAG} → using {cula_kda_fused_fwd.__module__}.{cula_kda_fused_fwd.__name__}" ) - # Fixed-length configs: (B, T) means H_qk=HV=H (no GVA); (B, T, H_qk, HV) + # ------------------------------------------------------------------ + # Fixed-length configs: (B, T) → H_qk=HV=H (no GVA); (B, T, H_qk, HV) # activates GVA when HV > H. The two forms can be freely mixed. + # ------------------------------------------------------------------ fixed_configs = [ # Non-GVA (H_qk == HV == H): (1, 512), @@ -483,13 +795,14 @@ def main(): (2, 4096), (2, 8192), (2, 16384), - # GVA (HV > H, same D=128): representative 2x and 4x ratios at key T values. - # Kept small intentionally — this script benchmarks the general KDA kernel; - # GVA configs are a supplementary check, not the primary focus. - (1, 4096, 32, 64), # 2x ratio - (1, 8192, 32, 64), # 2x ratio - (1, 4096, 16, 64), # 4x ratio - (1, 8192, 16, 64), # 4x ratio + # GVA (HV > H, same D=128): + (1, 1024, 16, 64), + (1, 4096, 16, 64), + (1, 8192, 16, 64), + (1, 4096, 32, 64), + (1, 8192, 32, 64), + (2, 4096, 16, 64), + (2, 8192, 16, 64), ] # Varlen configs: 3-tuples (seq_lens, total_len, dist) default to no GVA; @@ -499,26 +812,62 @@ def main(): total_lens=(4096, 8192, 16384), dists=("uniform", "random", "skewed"), ) - # GVA varlen: one representative config per total-length (random dist, 10 seqs, - # 4x ratio H=16→HV=64). Kept minimal — supplementary GVA coverage only. - gva_varlen = [ - (seq_lens, T, dist, 16, 64) - for (seq_lens, T, dist) in varlen_configs_base - if dist == "random" and len(seq_lens) == 10 + # A small GVA subset reuses the non-GVA varlen shapes with (H_qk=16, HV=64). + gva_varlen_mixed = [(seq_lens, T, dist, 16, 64) for (seq_lens, T, dist) in varlen_configs_base if T <= 8192] + varlen_configs = list(varlen_configs_base) + gva_varlen_mixed + + # ------------------------------------------------------------------ + # Dedicated GVA configs (multiple GVA ratios, uses prepare_gva_inputs) + # ------------------------------------------------------------------ + gva_fixed_configs = build_gva_fixed_configs( + batch_sizes=(1, 2), + seq_lens=(1024, 4096, 8192), + h_hv_pairs=((8, 32), (16, 64), (32, 64), (16, 128)), + ) + gva_varlen_configs = build_gva_varlen_configs( + h_hv_pairs=((16, 64), (32, 64)), + num_seqs_list=(10, 20), + total_lens=(4096, 8192), + dists=("uniform", "random", "skewed"), + ) + + # Overhead configs: (B, T, H_qk, HV) — HV > H required. + # For each row the benchmark runs cuLA twice: + # baseline → non-GVA with HV heads (H == HV) + # gva → GVA with H_qk heads expanded to HV + # Both kernel inputs have shape (1, B*T, HV, D), so overhead% ≈ 0 proves + # that GVA adds no kernel latency regression. 
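+    # Hedged illustration of the head expansion the GVA path performs (the
+    # `_q16` / `_q64` names below are throwaway examples, not used by this
+    # script): a 4x repeat of 16 q/k heads matches 64 native heads in shape.
+    #     _q16 = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
+    #     _q64 = _q16.repeat_interleave(4, dim=2)  # -> (1, 4096, 64, 128)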
+ overhead_configs = [ + (1, 1024, 16, 64), # 4x GVA ratio + (1, 4096, 16, 64), + (1, 8192, 16, 64), + (1, 16384, 16, 64), + (1, 4096, 32, 64), # 2x GVA ratio + (1, 8192, 32, 64), + (1, 4096, 8, 64), # 8x GVA ratio + (1, 8192, 8, 64), + (2, 4096, 16, 64), + (2, 8192, 16, 64), ] - varlen_configs = list(varlen_configs_base) + gva_varlen - fixed_res, varlen_res = [], [] + fixed_res, varlen_res, gva_fixed_res, gva_varlen_res, overhead_res = [], [], [], [], [] - if args.mode in ("fixed", "both"): + if args.mode in ("fixed", "both", "all"): fixed_res = bench_fixed(fixed_configs) - if args.mode in ("varlen", "both"): + if args.mode in ("varlen", "both", "all"): varlen_res = bench_varlen(varlen_configs) - print_report(fixed_res, varlen_res) + if args.mode in ("gva", "all"): + gva_fixed_res = bench_gva_fixed(gva_fixed_configs) + gva_varlen_res = bench_gva_varlen(gva_varlen_configs) + + if args.mode in ("overhead", "all"): + overhead_res = bench_gva_overhead(overhead_configs) + + print_report(fixed_res, varlen_res, gva_fixed_res, gva_varlen_res, overhead_res) - return fixed_res, varlen_res + return fixed_res, varlen_res, gva_fixed_res, gva_varlen_res, overhead_res if __name__ == "__main__": diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 5c5f918..dd0eed5 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -250,6 +250,66 @@ def build_varlen_configs( return configs +def build_gva_fixed_configs( + batch_sizes=(1, 2), + seq_lens=(1024, 4096, 8192), + h_hv_pairs=((8, 32), (16, 64), (32, 64), (16, 128)), +): + """Build (B, T, H, HV) fixed-length configs for dedicated GVA benchmarks. + + All returned configs satisfy HV > H (GVA is always active). The + ``h_hv_pairs`` argument lets callers control which (H, HV) / GVA-ratio + combinations are explored; the default covers 4x and 2x ratios used by + real KimiDeltaAttention deployments. + + Returns: + list of (B: int, T: int, H: int, HV: int) + """ + assert all(HV > H and HV % H == 0 for H, HV in h_hv_pairs), ( + "Every (H, HV) pair must satisfy HV > H and HV % H == 0." + ) + return [ + (B, T, H, HV) + for B in batch_sizes + for T in seq_lens + for H, HV in h_hv_pairs + ] + + +def build_gva_varlen_configs( + h_hv_pairs=((16, 64), (32, 64)), + num_seqs_list=(10, 20), + total_lens=(4096, 8192), + dists=("uniform", "random", "skewed"), + random_seed=42, +): + """Build varlen GVA configs as 5-tuples (seq_lens, total_len, dist, H, HV). + + Generates the non-GVA base configs via :func:`build_varlen_configs`, then + attaches each ``(H, HV)`` pair from *h_hv_pairs* to produce dedicated GVA + 5-tuples ready for consumption by the varlen benchmark runner. + + All (H, HV) pairs must satisfy HV > H and HV % H == 0. + + Returns: + list of (seq_lens: list[int], total_len: int, dist: str, H: int, HV: int) + """ + assert all(HV > H and HV % H == 0 for H, HV in h_hv_pairs), ( + "Every (H, HV) pair must satisfy HV > H and HV % H == 0." 
+ ) + base = build_varlen_configs( + num_seqs_list=num_seqs_list, + total_lens=total_lens, + dists=dists, + random_seed=random_seed, + ) + return [ + (seq_lens, T, dist, H, HV) + for H, HV in h_hv_pairs + for seq_lens, T, dist in base + ] + + # ============================================================================== # Common input preparation functions for benchmarks and demos # ============================================================================== @@ -358,61 +418,67 @@ def prepare_safe_gate_inputs( def prepare_gva_inputs( - batch_size, - T, - H, - HV, - D, - device, - cu_seqlens=None, - chunk_size=CHUNK_SIZE, - seed=SEED, - has_init_state=False, + batch_size, T, H, HV, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, has_init_state=False ): """Prepare inputs for GVA (Grouped Value Attention) benchmarks. - GVA uses separate head counts for QK (H) and V (HV), where HV > H and - HV is a positive multiple of H. The grouping factor is G = HV // H. + In GVA, num_v_heads (HV) > num_heads (H), with HV divisible by H. + q/k/g/beta are generated with H heads then repeated gva_ratio = HV // H times + to match v's HV heads, mirroring what KimiDeltaAttention does at the layer level. - Layout produced (after GVA expansion): - q, k, g, beta : HV heads (q/k expanded from H via repeat_interleave) - v : HV heads - A_log, dt_bias: sized to HV + All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility. + """ + assert HV % H == 0, f"HV={HV} must be divisible by H={H}" + gva_ratio = HV // H + dtype = torch.bfloat16 + scale = D ** (-0.5) - This is a thin, named wrapper around ``prepare_safe_gate_inputs`` with - ``num_v_heads=HV``. Having an explicit function makes the GVA data path - visible in benchmark scripts and easier to discover. + set_seed(seed) - Args: - batch_size: Batch size ``B``. - T: Per-sequence token count. - H: Number of QK heads (group count). - HV: Number of V heads; must satisfy ``HV > H`` and ``HV % H == 0``. - D: Head dimension. - device: Target torch device. - cu_seqlens: Optional cumulative sequence-length tensor for varlen mode. - chunk_size: Chunk size passed to ``prepare_chunk_indices``. - seed: RNG seed for reproducibility. - has_init_state: If True, allocate a non-zero recurrent initial state. + # Base tensors with H heads (as produced before the repeat in the layer) + q_base = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) + k_base = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) + v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) + g_base = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) + beta_base = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid().requires_grad_(False) - Returns: - Same dict as ``prepare_safe_gate_inputs``. 
- """ - if HV <= H: - raise ValueError(f"GVA requires HV > H, got H={H}, HV={HV}.") - if HV % H != 0: - raise ValueError(f"HV ({HV}) must be a positive multiple of H ({H}).") - return prepare_safe_gate_inputs( - batch_size=batch_size, - T=T, - H=H, - D=D, - device=device, + A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False) + + # Expand q, k, g, beta to HV heads (same as the layer's repeat) + from einops import repeat as einops_repeat + q = einops_repeat(q_base, "b t h d -> b t (h r) d", r=gva_ratio).contiguous() + k = einops_repeat(k_base, "b t h d -> b t (h r) d", r=gva_ratio).contiguous() + g = einops_repeat(g_base, "b t h d -> b t (h r) d", r=gva_ratio).contiguous() + beta = einops_repeat(beta_base, "b t h -> b t (h r)", r=gva_ratio).contiguous() + + # Flatten to batch_size=1 for cu_seqlens compatibility + if batch_size != 1: + q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + + init_state = None + if has_init_state: + num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size + init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False) + + return dict( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A_log=A_log, + dt_bias=dt_bias, + scale=scale, cu_seqlens=cu_seqlens, - chunk_size=chunk_size, - seed=seed, - has_init_state=has_init_state, - num_v_heads=HV, + chunk_indices=chunk_indices, + init_state=init_state, + lower_bound=-5.0, + H=H, + HV=HV, + gva_ratio=gva_ratio, ) From e5d24c37fbb6758ce50130c5339670333efc5adc Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Sat, 9 May 2026 12:09:47 +0800 Subject: [PATCH 15/17] fix --- benchmarks/bench_kda_fused_fwd.py | 488 ++++-------------------- benchmarks/utils.py | 127 +----- csrc/api/kda_sm90.cu | 3 +- csrc/kda/sm90/collective/store_tma.hpp | 2 +- csrc/kda/sm90/kernel/tile_scheduler.hpp | 10 +- cula/kda/hopper_fused_fwd.py | 9 +- tests/test_kda_fused_fwd.py | 4 +- 7 files changed, 79 insertions(+), 564 deletions(-) diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index 135527d..db5e9e6 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -25,31 +25,20 @@ - Accuracy: RMSE, relative max diff between cuLA fully-fused and FLA Triton - Performance: kernel execution time (ms) with CUDA events -Modes (--mode, default: all): - - fixed: Fixed-length sequences, various (B, T, H, HV) configs. - GVA rows (HV > H) are mixed in alongside non-GVA rows. - - varlen: Variable-length sequences with 2-3x length variation. - Non-GVA base configs plus a GVA subset (H=16, HV=64). - - gva: Dedicated GVA benchmark (fixed + varlen) using prepare_gva_inputs. - Covers multiple GVA ratios (2x / 4x / 8x); compares cuLA vs FLA. - - overhead: GVA overhead benchmark — cuLA GVA vs cuLA non-GVA at the same - total head count (HV). Both paths present identical tensor shapes - to the kernel, so a near-zero overhead% proves that GVA adds no - measurable kernel latency regression. - - both: Fixed-length + varlen only (legacy alias, no gva/overhead). - - all: Run all of the above. - -Under GVA (HV > H), q/k are expanded from H to HV heads via -`repeat_interleave(..., dim=2)`, equivalent to the einops pattern -`repeat(x, "... h d -> ... (h g) d")`. 
This keeps FLA's `chunk_kda` -(which does not natively support GVA) and cuLA's SM100 fully-fused forward -(which requires q/k/v to share the head dim) on the same input layout. +Modes (--mode, default: both): + - fixed: Fixed-length sequences, various (B, T, H, HV) configs. + When HV > H the row is a GVA (Grouped Value Attention) workload; + GVA and MHA rows share the same sequence-length settings so they + can be compared side by side in a single table. + - varlen: Variable-length sequences with 2-3x length variation, same mixing + of GVA (HV > H) and MHA (HV == H) rows. + - both: Run fixed + varlen (default). Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|gva|overhead|both|all] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: - ncu --set full -o report python bench_kda_fused_fwd.py --mode overhead --ncu + ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu """ import argparse @@ -65,11 +54,8 @@ from benchmarks.utils import ( SEED, - build_gva_fixed_configs, - build_gva_varlen_configs, build_varlen_configs, exclusive_cumsum, - prepare_gva_inputs, prepare_safe_gate_inputs, set_seed, ) @@ -87,7 +73,8 @@ # Constants # ============================================================ # Default number of Q/K heads. Each benchmark config may additionally specify -# HV (number of V heads) to enable GVA (HV > H must be a positive multiple of H). +# HV (number of V heads); when HV > H the row runs in GVA mode (the kernel +# sees HV expanded q/k heads, prepared internally by prepare_safe_gate_inputs). H, D = 64, 128 WARMUP = 25 N_ITERS = 100 @@ -197,9 +184,12 @@ def _normalize_varlen_config(cfg): return seq_lens, total_len, dist, H, H if len(cfg) == 5: return cfg - raise ValueError( - f"Varlen config must be (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H, HV), got {cfg!r}" - ) + raise ValueError(f"Varlen config must be (seq_lens, total_len, dist) or (seq_lens, total_len, dist, H, HV), got {cfg!r}") + + +def _gva_hint(H_qk, HV): + """Return a short tag marking GVA vs MHA rows for progress prints.""" + return f"[GVA {HV // H_qk}x]" if HV > H_qk else "[MHA]" # ============================================================ @@ -213,6 +203,7 @@ def bench_fixed(configs): for cfg in configs: B, T, H_qk, HV = _normalize_fixed_config(cfg) + print(f" {_gva_hint(H_qk, HV):>9s} B={B} T={T} H={H_qk} HV={HV}") set_seed(SEED) device = torch.device("cuda") torch.cuda.empty_cache() @@ -221,7 +212,11 @@ def bench_fixed(configs): cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) inputs = prepare_safe_gate_inputs( - B, T, H_qk, D, device, + B, + T, + H_qk, + D, + device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE, num_v_heads=HV, @@ -288,6 +283,7 @@ def bench_varlen(configs): for cfg in configs: seq_lens, total_len, dist, H_qk, HV = _normalize_varlen_config(cfg) + print(f" {_gva_hint(H_qk, HV):>9s} {dist} {len(seq_lens)}seqs T={total_len} H={H_qk} HV={HV}") set_seed(SEED) device = torch.device("cuda") torch.cuda.empty_cache() @@ -296,7 +292,11 @@ def bench_varlen(configs): cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) inputs = prepare_safe_gate_inputs( - 1, T, H_qk, D, device, + 1, + T, + H_qk, + D, + device, cu_seqlens=cu_seqlens, has_init_state=HAS_INIT_STATE, num_v_heads=HV, @@ -359,278 +359,16 @@ def bench_varlen(configs): return results -# ============================================================ -# 
GVA-dedicated benchmarks (uses prepare_gva_inputs) -# ============================================================ -def bench_gva_fixed(configs): - """Fixed-length GVA benchmark using :func:`prepare_gva_inputs`. - - All configs must have HV > H. Data is prepared the same way as in the - KimiDeltaAttention layer: q/k/g/beta are first generated with H heads and - then expanded to HV heads via einops repeat before being fed to both cuLA - and FLA. - """ - print("\n" + "=" * 100) - print(f" GVA Fixed-Length Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") - print("=" * 100) - results = [] - - for B, T, H_qk, HV in configs: - assert HV > H_qk, f"GVA requires HV > H, got H={H_qk} HV={HV}" - set_seed(SEED) - device = torch.device("cuda") - torch.cuda.empty_cache() - - seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - inputs = prepare_gva_inputs( - B, T, H_qk, HV, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - ) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] - A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] - scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - gva_ratio = inputs["gva_ratio"] - - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - ) - - # Accuracy - o_fla, _ = run_fla(**common) - o_cula, _ = run_cula(**common) - torch.cuda.synchronize() - - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) - - # Performance - ms_fla = time_kernel(lambda: run_fla(**common)) - ms_cula = time_kernel(lambda: run_cula(**common)) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - results.append( - { - "B": B, - "T": T, - "H": H_qk, - "HV": HV, - "gva_ratio": gva_ratio, - "rmse": rmse, - "rel_max": rel_max, - "mean_diff": mean_diff, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - ) - - del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs - torch.cuda.empty_cache() - - return results - - -def bench_gva_varlen(configs): - """Varlen GVA benchmark using :func:`prepare_gva_inputs`. - - Configs are 5-tuples (seq_lens, total_len, dist, H, HV) as produced by - :func:`~benchmarks.utils.build_gva_varlen_configs`. - All configs must have HV > H. 
- """ - print("\n" + "=" * 100) - print(f" GVA Varlen Benchmark: cuLA fully-fused ({_SM_TAG}) vs FLA Triton") - print("=" * 100) - results = [] - - for seq_lens, total_len, dist, H_qk, HV in configs: - assert HV > H_qk, f"GVA requires HV > H, got H={H_qk} HV={HV}" - set_seed(SEED) - device = torch.device("cuda") - torch.cuda.empty_cache() - - T = total_len - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - inputs = prepare_gva_inputs( - 1, T, H_qk, HV, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - ) - q, k, v, g, beta = inputs["q"], inputs["k"], inputs["v"], inputs["g"], inputs["beta"] - A_log, dt_bias = inputs["A_log"], inputs["dt_bias"] - scale, init_state, lower_bound = inputs["scale"], inputs["init_state"], inputs["lower_bound"] - gva_ratio = inputs["gva_ratio"] - - common = dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - A_log=A_log, - dt_bias=dt_bias, - init_state=init_state, - cu_seqlens=cu_seqlens, - lower_bound=lower_bound, - ) - - # Accuracy - o_fla, _ = run_fla(**common) - o_cula, _ = run_cula(**common) - torch.cuda.synchronize() - - rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula) - - # Performance - ms_fla = time_kernel(lambda: run_fla(**common)) - ms_cula = time_kernel(lambda: run_cula(**common)) - speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf") - - n_seqs = len(seq_lens) - min_l, max_l = min(seq_lens), max(seq_lens) - avg_l = T // n_seqs - tag = f"{dist:>7s} {n_seqs:>2d}seqs T={T} [{min_l}..{max_l}] avg={avg_l}" - - results.append( - { - "tag": tag, - "dist": dist, - "T_total": T, - "n_seqs": n_seqs, - "H": H_qk, - "HV": HV, - "gva_ratio": gva_ratio, - "rmse": rmse, - "rel_max": rel_max, - "mean_diff": mean_diff, - "ms_fla": ms_fla, - "ms_cula": ms_cula, - "speedup": speedup, - } - ) - - del o_fla, o_cula, q, k, v, g, beta, A_log, dt_bias, inputs - torch.cuda.empty_cache() - - return results - - -# ============================================================ -# GVA overhead benchmark -# (proves GVA adds no kernel cost vs a plain non-GVA run) -# ============================================================ -def bench_gva_overhead(configs): - """Quantify the kernel overhead introduced by GVA vs a plain non-GVA run. - - For every ``(B, T, H_qk, HV)`` config this function runs **cuLA only** - with two different input preparations — both produce tensors of identical - shape ``(1, B*T, HV, D)`` entering the kernel: - - * **baseline** – standard non-GVA: H = HV unique q/k heads, prepared via - :func:`prepare_safe_gate_inputs` with ``num_v_heads=HV``. - * **GVA** – grouped q/k: H < HV heads expanded to HV via - ``repeat_interleave``, prepared via :func:`prepare_gva_inputs`. - - Because the kernel receives identically-shaped tensors in both cases, the - extra work that GVA adds is *only* the ``repeat_interleave`` call done in - Python before the kernel is launched. A near-zero ``overhead%`` column in - the report confirms that the GVA feature introduces no measurable kernel - latency regression. - - Note: FLA is intentionally excluded; the comparison is purely cuLA vs cuLA. 
- """ - print("\n" + "=" * 100) - print(f" GVA Overhead Benchmark: cuLA GVA vs cuLA non-GVA (same kernel shape, {_SM_TAG})") - print("=" * 100) - results = [] - - for cfg in configs: - B, T, H_qk, HV = _normalize_fixed_config(cfg) - assert HV > H_qk, f"GVA overhead bench requires HV > H, got H={H_qk} HV={HV}" - gva_ratio = HV // H_qk - set_seed(SEED) - device = torch.device("cuda") - torch.cuda.empty_cache() - - seq_lens = [T] * B - cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device) - - # ── baseline: non-GVA with HV heads (H == HV) ──────────────────────── - inp_base = prepare_safe_gate_inputs( - B, T, HV, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - num_v_heads=HV, - ) - common_base = dict( - q=inp_base["q"], k=inp_base["k"], v=inp_base["v"], - g=inp_base["g"], beta=inp_base["beta"], - scale=inp_base["scale"], A_log=inp_base["A_log"], - dt_bias=inp_base["dt_bias"], init_state=inp_base["init_state"], - cu_seqlens=cu_seqlens, lower_bound=inp_base["lower_bound"], - ) - - # ── GVA: H_qk heads expanded to HV ─────────────────────────────────── - inp_gva = prepare_gva_inputs( - B, T, H_qk, HV, D, device, - cu_seqlens=cu_seqlens, - has_init_state=HAS_INIT_STATE, - ) - common_gva = dict( - q=inp_gva["q"], k=inp_gva["k"], v=inp_gva["v"], - g=inp_gva["g"], beta=inp_gva["beta"], - scale=inp_gva["scale"], A_log=inp_gva["A_log"], - dt_bias=inp_gva["dt_bias"], init_state=inp_gva["init_state"], - cu_seqlens=cu_seqlens, lower_bound=inp_gva["lower_bound"], - ) - - # ── performance ─────────────────────────────────────────────────────── - ms_base = time_kernel(lambda: run_cula(**common_base)) - ms_gva = time_kernel(lambda: run_cula(**common_gva)) - overhead_pct = (ms_gva - ms_base) / ms_base * 100.0 if ms_base > 0 else 0.0 - - results.append( - { - "B": B, - "T": T, - "H": H_qk, - "HV": HV, - "gva_ratio": gva_ratio, - "ms_base": ms_base, - "ms_gva": ms_gva, - "overhead_pct": overhead_pct, - } - ) - - del inp_base, inp_gva - torch.cuda.empty_cache() - - return results - - # ============================================================ # Report # ============================================================ -def print_report(fixed_results, varlen_results, gva_fixed_results=None, gva_varlen_results=None, overhead_results=None): +def print_report(fixed_results, varlen_results): sep = "=" * 120 print(f"\n\n{sep}") print(" BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)") print(f" cuLA {_SM_TAG} fully-fused vs FLA Triton") print(f" D={D} dtype=bf16 safe_gate=True has_init_state={HAS_INIT_STATE}") - print(f" GVA rows are those with HV > H (H, HV shown per row).") + print(" GVA rows are those with HV > H (H, HV shown per row).") wu = 1 if (NCU_MODE or SANITIZER_MODE) else WARMUP ni = 1 if (NCU_MODE or SANITIZER_MODE) else N_ITERS mode_tag = " [NCU mode]" if NCU_MODE else (" [Sanitizer mode]" if SANITIZER_MODE else "") @@ -647,7 +385,7 @@ def print_report(fixed_results, varlen_results, gva_fixed_results=None, gva_varl ) print(f" {'─' * 110}") for r in fixed_results: - gva_tag = "yes" if r["HV"] > r["H"] else "no" + gva_tag = f"{r['HV'] // r['H']}x" if r["HV"] > r["H"] else "no" print( f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " @@ -665,7 +403,7 @@ def print_report(fixed_results, varlen_results, gva_fixed_results=None, gva_varl ) print(f" {'─' * 120}") for r in varlen_results: - gva_tag = "yes" if r["HV"] > r["H"] else "no" + gva_tag = f"{r['HV'] // r['H']}x" 
if r["HV"] > r["H"] else "no" print( f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {gva_tag:>4s} │ " f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " @@ -673,57 +411,6 @@ def print_report(fixed_results, varlen_results, gva_fixed_results=None, gva_varl ) print(f" {'─' * 120}") - if gva_fixed_results: - print("\n [GVA Fixed-Length] (data prepared via prepare_gva_inputs)") - print(f" {'─' * 116}") - print( - f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'ratio':>5s} │ " - f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " - f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" - ) - print(f" {'─' * 116}") - for r in gva_fixed_results: - print( - f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {r['gva_ratio']:4d}x │ " - f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " - f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" - ) - print(f" {'─' * 116}") - - if gva_varlen_results: - print("\n [GVA Varlen] (data prepared via prepare_gva_inputs)") - print(f" {'─' * 126}") - print( - f" {'Config':>45s} {'H':>3s} {'HV':>3s} {'ratio':>5s} │ " - f"{'RMSE':>10s} {'rel_max':>10s} {'mean_diff':>10s} │ " - f"{'FLA(ms)':>9s} {'cuLA(ms)':>10s} {'Speedup':>8s}" - ) - print(f" {'─' * 126}") - for r in gva_varlen_results: - print( - f" {r['tag']:>45s} {r['H']:3d} {r['HV']:3d} {r['gva_ratio']:4d}x │ " - f"{r['rmse']:10.6f} {r['rel_max']:10.6f} {r['mean_diff']:10.6f} │ " - f"{r['ms_fla']:9.4f} {r['ms_cula']:10.4f} {r['speedup']:7.2f}x" - ) - print(f" {'─' * 126}") - - if overhead_results: - print("\n [GVA Overhead] cuLA GVA vs cuLA non-GVA — same kernel shape, same HV heads") - print(" (near-zero overhead% proves GVA adds no kernel latency)") - print(f" {'─' * 96}") - print( - f" {'B':>3s} {'T':>6s} {'H':>3s} {'HV':>3s} {'ratio':>5s} │ " - f"{'base(ms)':>10s} {'gva(ms)':>10s} {'overhead%':>10s}" - ) - print(f" {'─' * 96}") - for r in overhead_results: - flag = " ✓" if abs(r["overhead_pct"]) < 3.0 else " !" - print( - f" {r['B']:3d} {r['T']:6d} {r['H']:3d} {r['HV']:3d} {r['gva_ratio']:4d}x │ " - f"{r['ms_base']:10.4f} {r['ms_gva']:10.4f} {r['overhead_pct']:+9.2f}%{flag}" - ) - print(f" {'─' * 96}") - print(f"\n{sep}\n") @@ -735,16 +422,13 @@ def main(): parser.add_argument( "--mode", type=str, - default="all", - choices=["fixed", "varlen", "gva", "overhead", "both", "all"], + default="both", + choices=["fixed", "varlen", "both"], help=( - "Which benchmark mode to run (default: all). " - "'all' = fixed + varlen + gva + overhead. " - "'both' = fixed + varlen only (legacy alias). " - "'gva' runs the dedicated GVA benchmark (fixed + varlen, " - "data prepared via prepare_gva_inputs with multiple GVA ratios). " - "'overhead' compares cuLA-GVA vs cuLA-non-GVA at the same total " - "head count to prove GVA adds no kernel latency regression." + "Which benchmark mode to run (default: both). " + "GVA rows (HV > H) are mixed in alongside MHA rows (HV == H) " + "under the same sequence-length settings, so GVA and MHA can be " + "compared side by side." ), ) parser.add_argument( @@ -780,94 +464,56 @@ def main(): ) # ------------------------------------------------------------------ - # Fixed-length configs: (B, T) → H_qk=HV=H (no GVA); (B, T, H_qk, HV) - # activates GVA when HV > H. The two forms can be freely mixed. + # Fixed-length configs — MHA and GVA rows share the same (B, T) grid. 
+ # (B, T) → MHA (H_qk = HV = default H) + # (B, T, H_qk, HV) → GVA when HV > H_qk (HV must be a multiple of H_qk) # ------------------------------------------------------------------ fixed_configs = [ - # Non-GVA (H_qk == HV == H): - (1, 512), + # MHA (HV == H): (1, 1024), (1, 4096), (1, 8192), (1, 16384), - (2, 512), - (2, 1024), (2, 4096), (2, 8192), - (2, 16384), - # GVA (HV > H, same D=128): - (1, 1024, 16, 64), - (1, 4096, 16, 64), - (1, 8192, 16, 64), - (1, 4096, 32, 64), - (1, 8192, 32, 64), - (2, 4096, 16, 64), - (2, 8192, 16, 64), + # GVA (HV > H) at the same (B, T) shapes for side-by-side comparison: + (1, 1024, 16, 64), # 4x + (1, 4096, 16, 64), # 4x + (1, 8192, 16, 64), # 4x + (1, 16384, 16, 64), # 4x + (1, 4096, 32, 64), # 2x + (1, 8192, 32, 64), # 2x + (1, 4096, 8, 64), # 8x + (1, 8192, 8, 64), # 8x + (2, 4096, 16, 64), # 4x + (2, 8192, 16, 64), # 4x ] - # Varlen configs: 3-tuples (seq_lens, total_len, dist) default to no GVA; - # extend to 5-tuples (..., H_qk, HV) to activate GVA on varlen workloads. + # Varlen configs — identical sequence-length layouts replayed with and + # without GVA so MHA and GVA can be compared row-by-row. varlen_configs_base = build_varlen_configs( - num_seqs_list=(10, 20), - total_lens=(4096, 8192, 16384), - dists=("uniform", "random", "skewed"), - ) - # A small GVA subset reuses the non-GVA varlen shapes with (H_qk=16, HV=64). - gva_varlen_mixed = [(seq_lens, T, dist, 16, 64) for (seq_lens, T, dist) in varlen_configs_base if T <= 8192] - varlen_configs = list(varlen_configs_base) + gva_varlen_mixed - - # ------------------------------------------------------------------ - # Dedicated GVA configs (multiple GVA ratios, uses prepare_gva_inputs) - # ------------------------------------------------------------------ - gva_fixed_configs = build_gva_fixed_configs( - batch_sizes=(1, 2), - seq_lens=(1024, 4096, 8192), - h_hv_pairs=((8, 32), (16, 64), (32, 64), (16, 128)), - ) - gva_varlen_configs = build_gva_varlen_configs( - h_hv_pairs=((16, 64), (32, 64)), num_seqs_list=(10, 20), total_lens=(4096, 8192), dists=("uniform", "random", "skewed"), ) - - # Overhead configs: (B, T, H_qk, HV) — HV > H required. - # For each row the benchmark runs cuLA twice: - # baseline → non-GVA with HV heads (H == HV) - # gva → GVA with H_qk heads expanded to HV - # Both kernel inputs have shape (1, B*T, HV, D), so overhead% ≈ 0 proves - # that GVA adds no kernel latency regression. 
- overhead_configs = [ - (1, 1024, 16, 64), # 4x GVA ratio - (1, 4096, 16, 64), - (1, 8192, 16, 64), - (1, 16384, 16, 64), - (1, 4096, 32, 64), # 2x GVA ratio - (1, 8192, 32, 64), - (1, 4096, 8, 64), # 8x GVA ratio - (1, 8192, 8, 64), - (2, 4096, 16, 64), - (2, 8192, 16, 64), + gva_varlen_mixed = [ + (seq_lens, T, dist, H_qk, HV) + for (H_qk, HV) in ((16, 64), (32, 64)) + for (seq_lens, T, dist) in varlen_configs_base ] + varlen_configs = list(varlen_configs_base) + gva_varlen_mixed - fixed_res, varlen_res, gva_fixed_res, gva_varlen_res, overhead_res = [], [], [], [], [] + fixed_res, varlen_res = [], [] - if args.mode in ("fixed", "both", "all"): + if args.mode in ("fixed", "both"): fixed_res = bench_fixed(fixed_configs) - if args.mode in ("varlen", "both", "all"): + if args.mode in ("varlen", "both"): varlen_res = bench_varlen(varlen_configs) - if args.mode in ("gva", "all"): - gva_fixed_res = bench_gva_fixed(gva_fixed_configs) - gva_varlen_res = bench_gva_varlen(gva_varlen_configs) - - if args.mode in ("overhead", "all"): - overhead_res = bench_gva_overhead(overhead_configs) - - print_report(fixed_res, varlen_res, gva_fixed_res, gva_varlen_res, overhead_res) + print_report(fixed_res, varlen_res) - return fixed_res, varlen_res, gva_fixed_res, gva_varlen_res, overhead_res + return fixed_res, varlen_res if __name__ == "__main__": diff --git a/benchmarks/utils.py b/benchmarks/utils.py index dd0eed5..7fa3e13 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -250,66 +250,6 @@ def build_varlen_configs( return configs -def build_gva_fixed_configs( - batch_sizes=(1, 2), - seq_lens=(1024, 4096, 8192), - h_hv_pairs=((8, 32), (16, 64), (32, 64), (16, 128)), -): - """Build (B, T, H, HV) fixed-length configs for dedicated GVA benchmarks. - - All returned configs satisfy HV > H (GVA is always active). The - ``h_hv_pairs`` argument lets callers control which (H, HV) / GVA-ratio - combinations are explored; the default covers 4x and 2x ratios used by - real KimiDeltaAttention deployments. - - Returns: - list of (B: int, T: int, H: int, HV: int) - """ - assert all(HV > H and HV % H == 0 for H, HV in h_hv_pairs), ( - "Every (H, HV) pair must satisfy HV > H and HV % H == 0." - ) - return [ - (B, T, H, HV) - for B in batch_sizes - for T in seq_lens - for H, HV in h_hv_pairs - ] - - -def build_gva_varlen_configs( - h_hv_pairs=((16, 64), (32, 64)), - num_seqs_list=(10, 20), - total_lens=(4096, 8192), - dists=("uniform", "random", "skewed"), - random_seed=42, -): - """Build varlen GVA configs as 5-tuples (seq_lens, total_len, dist, H, HV). - - Generates the non-GVA base configs via :func:`build_varlen_configs`, then - attaches each ``(H, HV)`` pair from *h_hv_pairs* to produce dedicated GVA - 5-tuples ready for consumption by the varlen benchmark runner. - - All (H, HV) pairs must satisfy HV > H and HV % H == 0. - - Returns: - list of (seq_lens: list[int], total_len: int, dist: str, H: int, HV: int) - """ - assert all(HV > H and HV % H == 0 for H, HV in h_hv_pairs), ( - "Every (H, HV) pair must satisfy HV > H and HV % H == 0." 
- ) - base = build_varlen_configs( - num_seqs_list=num_seqs_list, - total_lens=total_lens, - dists=dists, - random_seed=random_seed, - ) - return [ - (seq_lens, T, dist, H, HV) - for H, HV in h_hv_pairs - for seq_lens, T, dist in base - ] - - # ============================================================================== # Common input preparation functions for benchmarks and demos # ============================================================================== @@ -350,7 +290,7 @@ def prepare_safe_gate_inputs( `kda_gate_chunk_cumsum` indexes them per head of `g`. This expanded layout works on both cuLA backends (SM100 requires q.shape[-2] == v.shape[-2]; SM90 accepts both native and expanded - GVA) and on FLA's `chunk_kda` (which does not natively support GVA). + GVA) and on FLA's `chunk_kda`. All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility. The returned dict always contains `H` and `HV` keys so callers can report @@ -417,71 +357,6 @@ def prepare_safe_gate_inputs( ) -def prepare_gva_inputs( - batch_size, T, H, HV, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED, has_init_state=False -): - """Prepare inputs for GVA (Grouped Value Attention) benchmarks. - - In GVA, num_v_heads (HV) > num_heads (H), with HV divisible by H. - q/k/g/beta are generated with H heads then repeated gva_ratio = HV // H times - to match v's HV heads, mirroring what KimiDeltaAttention does at the layer level. - - All tensors are flattened to (1, B*T, ...) for cu_seqlens compatibility. - """ - assert HV % H == 0, f"HV={HV} must be divisible by H={H}" - gva_ratio = HV // H - dtype = torch.bfloat16 - scale = D ** (-0.5) - - set_seed(seed) - - # Base tensors with H heads (as produced before the repeat in the layer) - q_base = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - k_base = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - v = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) - g_base = torch.randn(batch_size, T, H, D, dtype=dtype, device=device).requires_grad_(False) - beta_base = torch.randn(batch_size, T, H, dtype=torch.float, device=device).sigmoid().requires_grad_(False) - - A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) - dt_bias = torch.randn(HV * D, dtype=torch.float, device=device).requires_grad_(False) - - # Expand q, k, g, beta to HV heads (same as the layer's repeat) - from einops import repeat as einops_repeat - q = einops_repeat(q_base, "b t h d -> b t (h r) d", r=gva_ratio).contiguous() - k = einops_repeat(k_base, "b t h d -> b t (h r) d", r=gva_ratio).contiguous() - g = einops_repeat(g_base, "b t h d -> b t (h r) d", r=gva_ratio).contiguous() - beta = einops_repeat(beta_base, "b t h -> b t (h r)", r=gva_ratio).contiguous() - - # Flatten to batch_size=1 for cu_seqlens compatibility - if batch_size != 1: - q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... 
-> 1 (b t) ..."), (q, k, v, g, beta)) - - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None - - init_state = None - if has_init_state: - num_seqs = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch_size - init_state = torch.randn(num_seqs, HV, D, D, dtype=torch.float, device=device).requires_grad_(False) - - return dict( - q=q, - k=k, - v=v, - g=g, - beta=beta, - A_log=A_log, - dt_bias=dt_bias, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - init_state=init_state, - lower_bound=-5.0, - H=H, - HV=HV, - gva_ratio=gva_ratio, - ) - - def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_size=CHUNK_SIZE, seed=SEED): """Prepare preprocessed inputs ready for chunk_kda_fwd_intra. diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index a07958a..9e016eb 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -117,8 +117,7 @@ kda_fwd_prefill( beta.dtype()); TORCH_CHECK(beta.is_contiguous(), "beta must be contiguous"); TORCH_CHECK( - beta.size(0) == packed_seq && beta.size(1) == num_v_heads, - "beta shape must be [packed_seq, num_v_heads]"); + beta.size(0) == packed_seq && beta.size(1) == num_v_heads, "beta shape must be [packed_seq, num_v_heads]"); } if (input_state_.has_value()) { auto& input_state = input_state_.value(); diff --git a/csrc/kda/sm90/collective/store_tma.hpp b/csrc/kda/sm90/collective/store_tma.hpp index 6e2a784..0f7f7c1 100644 --- a/csrc/kda/sm90/collective/store_tma.hpp +++ b/csrc/kda/sm90/collective/store_tma.hpp @@ -195,7 +195,7 @@ struct CollectiveStoreTma { Tensor m_varlen_head = tma_store_.get_tma_tensor(make_shape( problem_size.head_size, problem_size.total_seqlen, - problem_size.num_v_heads)); // O lives in the V/O head space under GVA + problem_size.num_v_heads)); // O lives in the V/O head space under GVA Tensor m_varlen = m_varlen_head(_, _, work_desc.o_head_idx()); // slice into current head_idx Tensor m_offset = domain_offset( make_coord(_0{}, work_desc.tok_offset), diff --git a/csrc/kda/sm90/kernel/tile_scheduler.hpp b/csrc/kda/sm90/kernel/tile_scheduler.hpp index bdd79bf..c70c7a6 100644 --- a/csrc/kda/sm90/kernel/tile_scheduler.hpp +++ b/csrc/kda/sm90/kernel/tile_scheduler.hpp @@ -24,13 +24,13 @@ using namespace cute; struct WorkDesc { // coord - int32_t seq_idx; // which sequence to process - int32_t qk_head_idx; // head idx for Q/K (the representative of the GVA group) - int32_t head_idx; // head idx for V/O/g/beta - int64_t tok_offset; // start offset of this sequence in the packed tensor + int32_t seq_idx; // which sequence to process + int32_t qk_head_idx; // head idx for Q/K (the representative of the GVA group) + int32_t head_idx; // head idx for V/O/g/beta + int64_t tok_offset; // start offset of this sequence in the packed tensor // shape - int64_t seq_len; // length of this sequence + int64_t seq_len; // length of this sequence // update by mainloop int32_t tile_idx = 0; // current tile index (mutated by the mainloop) diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py index f310e2f..c0399bb 100644 --- a/cula/kda/hopper_fused_fwd.py +++ b/cula/kda/hopper_fused_fwd.py @@ -24,6 +24,7 @@ import cula.cudac as cula_cuda from cula.utils import _get_cache_buf, assert_hopper, get_device_sm_count, prepare_uniform_cu_seqlens + class HopperChunkKDAFunction(torch.autograd.Function): @staticmethod @input_guard @@ -51,9 +52,7 @@ def forward( # GVA: q/k share num_qk_heads; v/g/beta share num_v_heads. 
# num_v_heads must be a positive multiple of num_qk_heads (heads_per_group = HV / H). assert q.shape == k.shape, "q and k must have the same shape." - assert q.shape[:2] == v.shape[:2] == g.shape[:2], ( - "q, k, v, g must share batch and sequence dimensions." - ) + assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share batch and sequence dimensions." batch_size, seq_len, num_qk_heads, head_dim = q.shape num_v_heads = v.shape[-2] @@ -230,9 +229,7 @@ def cula_kda_prefill( raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") assert q.shape == k.shape, "q and k must have the same shape." - assert q.shape[:2] == v.shape[:2] == g.shape[:2], ( - "q, k, v, g must share batch and sequence dimensions." - ) + assert q.shape[:2] == v.shape[:2] == g.shape[:2], "q, k, v, g must share batch and sequence dimensions." batch_size, seq_len, num_qk_heads, head_dim = q.shape num_v_heads = v.shape[-2] diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index f228459..354f0c9 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -49,9 +49,7 @@ [ pytest.param( *test, - id=("B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-init{}-{}").format( - *test - ), + id=("B{}-T{}-H{}-HV{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-init{}-{}").format(*test), ) for test in [ (1, 63, 1, 1, 128, 1, 0, False, False, True, True, torch.bfloat16), From 955e536095e2c7bf4a2ffd67760b2b2668f40442 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Sat, 9 May 2026 14:32:03 +0800 Subject: [PATCH 16/17] benchmark --- BENCHMARK_H200.md | 88 ++++++++++++++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 181e397..06e6e18 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,8 +1,8 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-04-05. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-09. -> **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129 +> **GPU:** NVIDIA H20 | **CUDA:** 12.8 | **PyTorch:** 2.9.1+cu128 > FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) @@ -16,41 +16,65 @@ Fully-fused KDA forward prefill kernel (sm90). 
| B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |---|---|-----------------|-----------------|---------| -| 1 | 512 | 0.576 | 0.230 | **2.51x** | -| 1 | 1024 | 0.572 | 0.248 | **2.31x** | -| 1 | 4096 | 0.936 | 0.899 | **1.04x** | -| 1 | 8192 | 1.819 | 1.758 | **1.03x** | -| 1 | 16384 | 3.599 | 3.521 | **1.02x** | -| 2 | 512 | 0.569 | 0.228 | **2.49x** | -| 2 | 1024 | 0.572 | 0.306 | **1.87x** | -| 2 | 4096 | 1.818 | 1.108 | **1.64x** | -| 2 | 8192 | 3.605 | 2.210 | **1.63x** | -| 2 | 16384 | 7.173 | 4.485 | **1.60x** | +| 1 | 1024 | 0.444 | 0.265 | **1.67x** | +| 1 | 4096 | 1.425 | 0.942 | **1.51x** | +| 1 | 8192 | 2.789 | 1.848 | **1.51x** | +| 1 | 16384 | 5.526 | 3.698 | **1.49x** | +| 2 | 4096 | 2.781 | 1.869 | **1.49x** | +| 2 | 8192 | 5.508 | 3.692 | **1.49x** | +| 1 | 1024 | 0.447 | 0.265 | **1.69x** | +| 1 | 4096 | 1.425 | 0.942 | **1.51x** | +| 1 | 8192 | 2.790 | 1.853 | **1.51x** | +| 1 | 16384 | 5.517 | 3.683 | **1.50x** | +| 1 | 4096 | 1.418 | 0.936 | **1.51x** | +| 1 | 8192 | 2.776 | 1.834 | **1.51x** | +| 1 | 4096 | 1.421 | 0.938 | **1.52x** | +| 1 | 8192 | 2.781 | 1.836 | **1.51x** | +| 2 | 4096 | 2.774 | 1.858 | **1.49x** | +| 2 | 8192 | 5.486 | 3.666 | **1.50x** | ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |--------|-----------------|-----------------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 1.016 | 0.707 | **1.44x** | -| random 10seqs T=4096 [24..1201] avg=409 | 1.008 | 0.660 | **1.53x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 1.005 | 0.668 | **1.50x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 1.087 | 0.919 | **1.18x** | -| random 20seqs T=4096 [5..787] avg=204 | 1.066 | 0.736 | **1.45x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 1.038 | 0.724 | **1.43x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.855 | 1.179 | **1.57x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.893 | 1.215 | **1.56x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.906 | 1.209 | **1.58x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.961 | 1.406 | **1.39x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.954 | 1.283 | **1.52x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.646 | 2.188 | **1.67x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 3.646 | 2.306 | **1.58x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.656 | 2.335 | **1.57x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 3.679 | 2.355 | **1.56x** | -| random 20seqs T=16384 [19..3147] avg=819 | 3.713 | 2.323 | **1.60x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 3.670 | 2.384 | **1.54x** | - -Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x. 
+| uniform 10seqs T=4096 [409..415] avg=409 | 1.500 | 1.157 | **1.30x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.481 | 1.029 | **1.44x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.486 | 1.006 | **1.48x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.639 | 1.327 | **1.24x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.585 | 1.134 | **1.40x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.530 | 1.107 | **1.38x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 2.782 | 1.939 | **1.43x** | +| random 10seqs T=8192 [48..2401] avg=819 | 2.823 | 1.878 | **1.50x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 2.830 | 1.887 | **1.50x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 2.910 | 2.197 | **1.32x** | +| random 20seqs T=8192 [9..1574] avg=409 | 2.907 | 1.968 | **1.48x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 2.919 | 2.028 | **1.44x** | +| uniform 10seqs T=4096 [409..415] avg=409 | 1.501 | 1.161 | **1.29x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.481 | 1.032 | **1.43x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.486 | 1.006 | **1.48x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.638 | 1.326 | **1.24x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.585 | 1.137 | **1.39x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.531 | 1.113 | **1.38x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 2.783 | 1.938 | **1.44x** | +| random 10seqs T=8192 [48..2401] avg=819 | 2.823 | 1.877 | **1.50x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 2.828 | 1.885 | **1.50x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 2.909 | 2.194 | **1.33x** | +| random 20seqs T=8192 [9..1574] avg=409 | 2.906 | 1.962 | **1.48x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 2.921 | 2.024 | **1.44x** | +| uniform 10seqs T=4096 [409..415] avg=409 | 1.500 | 1.160 | **1.29x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.483 | 1.032 | **1.44x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.484 | 1.006 | **1.47x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.638 | 1.324 | **1.24x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.584 | 1.137 | **1.39x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.533 | 1.108 | **1.38x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 2.782 | 1.939 | **1.43x** | +| random 10seqs T=8192 [48..2401] avg=819 | 2.818 | 1.876 | **1.50x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 2.829 | 1.883 | **1.50x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 2.909 | 2.196 | **1.32x** | +| random 20seqs T=8192 [9..1574] avg=409 | 2.903 | 1.964 | **1.48x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 2.917 | 2.010 | **1.45x** | + +Summary (52 configs): **avg=1.45x**, min=1.24x, max=1.69x. To reproduce: From 09dbba94fad2dac1712dc01f85452a3df0195e70 Mon Sep 17 00:00:00 2001 From: sunnyxyli Date: Sat, 9 May 2026 14:36:01 +0800 Subject: [PATCH 17/17] benchmark --- BENCHMARK_H200.md | 88 +++++++++++++++++------------------------------ 1 file changed, 32 insertions(+), 56 deletions(-) diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 06e6e18..181e397 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,8 +1,8 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-09. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-04-05. 
-> **GPU:** NVIDIA H20 | **CUDA:** 12.8 | **PyTorch:** 2.9.1+cu128
+> **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129
 
 > FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)
 
@@ -16,65 +16,41 @@ Fully-fused KDA forward prefill kernel (sm90).
 | B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup |
 |---|---|-----------------|-----------------|---------|
-| 1 | 1024 | 0.444 | 0.265 | **1.67x** |
-| 1 | 4096 | 1.425 | 0.942 | **1.51x** |
-| 1 | 8192 | 2.789 | 1.848 | **1.51x** |
-| 1 | 16384 | 5.526 | 3.698 | **1.49x** |
-| 2 | 4096 | 2.781 | 1.869 | **1.49x** |
-| 2 | 8192 | 5.508 | 3.692 | **1.49x** |
-| 1 | 1024 | 0.447 | 0.265 | **1.69x** |
-| 1 | 4096 | 1.425 | 0.942 | **1.51x** |
-| 1 | 8192 | 2.790 | 1.853 | **1.51x** |
-| 1 | 16384 | 5.517 | 3.683 | **1.50x** |
-| 1 | 4096 | 1.418 | 0.936 | **1.51x** |
-| 1 | 8192 | 2.776 | 1.834 | **1.51x** |
-| 1 | 4096 | 1.421 | 0.938 | **1.52x** |
-| 1 | 8192 | 2.781 | 1.836 | **1.51x** |
-| 2 | 4096 | 2.774 | 1.858 | **1.49x** |
-| 2 | 8192 | 5.486 | 3.666 | **1.50x** |
+| 1 | 512 | 0.576 | 0.230 | **2.51x** |
+| 1 | 1024 | 0.572 | 0.248 | **2.31x** |
+| 1 | 4096 | 0.936 | 0.899 | **1.04x** |
+| 1 | 8192 | 1.819 | 1.758 | **1.03x** |
+| 1 | 16384 | 3.599 | 3.521 | **1.02x** |
+| 2 | 512 | 0.569 | 0.228 | **2.49x** |
+| 2 | 1024 | 0.572 | 0.306 | **1.87x** |
+| 2 | 4096 | 1.818 | 1.108 | **1.64x** |
+| 2 | 8192 | 3.605 | 2.210 | **1.63x** |
+| 2 | 16384 | 7.173 | 4.485 | **1.60x** |
 
 ### Variable-Length (H=64, D=128, bf16)
 
 | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup |
 |--------|-----------------|-----------------|---------|
-| uniform 10seqs T=4096 [409..415] avg=409 | 1.500 | 1.157 | **1.30x** |
-| random 10seqs T=4096 [24..1201] avg=409 | 1.481 | 1.029 | **1.44x** |
-| skewed 10seqs T=4096 [227..2053] avg=409 | 1.486 | 1.006 | **1.48x** |
-| uniform 20seqs T=4096 [204..220] avg=204 | 1.639 | 1.327 | **1.24x** |
-| random 20seqs T=4096 [5..787] avg=204 | 1.585 | 1.134 | **1.40x** |
-| skewed 20seqs T=4096 [107..2063] avg=204 | 1.530 | 1.107 | **1.38x** |
-| uniform 10seqs T=8192 [819..821] avg=819 | 2.782 | 1.939 | **1.43x** |
-| random 10seqs T=8192 [48..2401] avg=819 | 2.823 | 1.878 | **1.50x** |
-| skewed 10seqs T=8192 [455..4097] avg=819 | 2.830 | 1.887 | **1.50x** |
-| uniform 20seqs T=8192 [409..421] avg=409 | 2.910 | 2.197 | **1.32x** |
-| random 20seqs T=8192 [9..1574] avg=409 | 2.907 | 1.968 | **1.48x** |
-| skewed 20seqs T=8192 [215..4107] avg=409 | 2.919 | 2.028 | **1.44x** |
-| uniform 10seqs T=4096 [409..415] avg=409 | 1.501 | 1.161 | **1.29x** |
-| random 10seqs T=4096 [24..1201] avg=409 | 1.481 | 1.032 | **1.43x** |
-| skewed 10seqs T=4096 [227..2053] avg=409 | 1.486 | 1.006 | **1.48x** |
-| uniform 20seqs T=4096 [204..220] avg=204 | 1.638 | 1.326 | **1.24x** |
-| random 20seqs T=4096 [5..787] avg=204 | 1.585 | 1.137 | **1.39x** |
-| skewed 20seqs T=4096 [107..2063] avg=204 | 1.531 | 1.113 | **1.38x** |
-| uniform 10seqs T=8192 [819..821] avg=819 | 2.783 | 1.938 | **1.44x** |
-| random 10seqs T=8192 [48..2401] avg=819 | 2.823 | 1.877 | **1.50x** |
-| skewed 10seqs T=8192 [455..4097] avg=819 | 2.828 | 1.885 | **1.50x** |
-| uniform 20seqs T=8192 [409..421] avg=409 | 2.909 | 2.194 | **1.33x** |
-| random 20seqs T=8192 [9..1574] avg=409 | 2.906 | 1.962 | **1.48x** |
-| skewed 20seqs T=8192 [215..4107] avg=409 | 2.921 | 2.024 | **1.44x** |
-| uniform 10seqs T=4096 [409..415] avg=409 | 1.500 | 1.160 | **1.29x** |
-| random 10seqs T=4096 [24..1201] avg=409 | 1.483 | 1.032 | **1.44x** |
-| skewed 10seqs T=4096 [227..2053] avg=409 | 1.484 | 1.006 | **1.47x** |
-| uniform 20seqs T=4096 [204..220] avg=204 | 1.638 | 1.324 | **1.24x** |
-| random 20seqs T=4096 [5..787] avg=204 | 1.584 | 1.137 | **1.39x** |
-| skewed 20seqs T=4096 [107..2063] avg=204 | 1.533 | 1.108 | **1.38x** |
-| uniform 10seqs T=8192 [819..821] avg=819 | 2.782 | 1.939 | **1.43x** |
-| random 10seqs T=8192 [48..2401] avg=819 | 2.818 | 1.876 | **1.50x** |
-| skewed 10seqs T=8192 [455..4097] avg=819 | 2.829 | 1.883 | **1.50x** |
-| uniform 20seqs T=8192 [409..421] avg=409 | 2.909 | 2.196 | **1.32x** |
-| random 20seqs T=8192 [9..1574] avg=409 | 2.903 | 1.964 | **1.48x** |
-| skewed 20seqs T=8192 [215..4107] avg=409 | 2.917 | 2.010 | **1.45x** |
-
-Summary (52 configs): **avg=1.45x**, min=1.24x, max=1.69x.
+| uniform 10seqs T=4096 [409..415] avg=409 | 1.016 | 0.707 | **1.44x** |
+| random 10seqs T=4096 [24..1201] avg=409 | 1.008 | 0.660 | **1.53x** |
+| skewed 10seqs T=4096 [227..2053] avg=409 | 1.005 | 0.668 | **1.50x** |
+| uniform 20seqs T=4096 [204..220] avg=204 | 1.087 | 0.919 | **1.18x** |
+| random 20seqs T=4096 [5..787] avg=204 | 1.066 | 0.736 | **1.45x** |
+| skewed 20seqs T=4096 [107..2063] avg=204 | 1.038 | 0.724 | **1.43x** |
+| uniform 10seqs T=8192 [819..821] avg=819 | 1.855 | 1.179 | **1.57x** |
+| random 10seqs T=8192 [48..2401] avg=819 | 1.893 | 1.215 | **1.56x** |
+| skewed 10seqs T=8192 [455..4097] avg=819 | 1.906 | 1.209 | **1.58x** |
+| uniform 20seqs T=8192 [409..421] avg=409 | 1.961 | 1.406 | **1.39x** |
+| random 20seqs T=8192 [9..1574] avg=409 | 1.954 | 1.283 | **1.52x** |
+| skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** |
+| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.646 | 2.188 | **1.67x** |
+| random 10seqs T=16384 [95..4802] avg=1638 | 3.646 | 2.306 | **1.58x** |
+| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.656 | 2.335 | **1.57x** |
+| uniform 20seqs T=16384 [819..823] avg=819 | 3.679 | 2.355 | **1.56x** |
+| random 20seqs T=16384 [19..3147] avg=819 | 3.713 | 2.323 | **1.60x** |
+| skewed 20seqs T=16384 [431..8195] avg=819 | 3.670 | 2.384 | **1.54x** |
+
+Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x.
 
 To reproduce:
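The command that follows "To reproduce:" falls outside the hunks above, so it is not shown here. As a cross-check of the summary line instead: each bold speedup is the FLA-to-cuLA time ratio for its row, and the summary aggregates those ratios across configs. Below is a minimal, hypothetical sketch of that arithmetic over the restored fixed-length H200 rows only; it is not part of the repository, and the arithmetic mean (rather than, say, a geometric mean) is our assumption about how `generate_benchmark_hopper_md.py` aggregates.

```python
# Hypothetical sanity check of the benchmark summary row (not the repo's script).
# (B, T, fla_ms, cula_ms) tuples are copied from the fixed-length H200 table above.
rows = [
    (1, 512, 0.576, 0.230),
    (1, 1024, 0.572, 0.248),
    (1, 4096, 0.936, 0.899),
    (1, 8192, 1.819, 1.758),
    (1, 16384, 3.599, 3.521),
    (2, 512, 0.569, 0.228),
    (2, 1024, 0.572, 0.306),
    (2, 4096, 1.818, 1.108),
    (2, 8192, 3.605, 2.210),
    (2, 16384, 7.173, 4.485),
]

# Per-config speedup is the ratio of baseline time to fused-kernel time.
speedups = [fla_ms / cula_ms for _, _, fla_ms, cula_ms in rows]

# Arithmetic mean over configs is an assumption; the script may aggregate differently.
print(f"avg={sum(speedups) / len(speedups):.2f}x, "
      f"min={min(speedups):.2f}x, max={max(speedups):.2f}x")
# -> avg=1.71x, min=1.02x, max=2.50x over this 10-row subset
```

Because the table prints timings rounded to three decimals, recomputed per-row ratios can drift in the last digit (0.576/0.230 ≈ 2.50x against the table's 2.51x), and the subset average printed here is higher than the published 28-config avg=1.58x, which also folds in the variable-length rows.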