diff --git a/csrc/api/kda_sm100.cu b/csrc/api/kda_sm100.cu index ac32411..ff89887 100644 --- a/csrc/api/kda_sm100.cu +++ b/csrc/api/kda_sm100.cu @@ -37,7 +37,33 @@ ChunkKDAFwdIntra( KDA_fwd_intra_params params; params.total_q_len = q.size(0) * q.size(1); params.b = cu_seqlens.size(0) - 1; - params.h = q.size(2); + // GVA: Q/K are in h_qk head space (from q.size(2)); g/beta/Aqk/Akk are in h_v head + // space (from g.size(2)). When HV == HQK, heads_per_group == 1 and behaviour matches + // the pre-GVA path. + params.h_qk = q.size(2); + params.h_v = g.size(2); + TORCH_CHECK( + k.size(2) == params.h_qk, + "ChunkKDAFwdIntra: k.size(2) (", + k.size(2), + ") must match q.size(2) (", + params.h_qk, + ") under GVA (Q/K share h_qk)."); + TORCH_CHECK( + beta.size(-1) == params.h_v, + "ChunkKDAFwdIntra: beta.size(-1) (", + beta.size(-1), + ") must equal h_v (", + params.h_v, + ")."); + TORCH_CHECK( + params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0, + "ChunkKDAFwdIntra: h_v (", + params.h_v, + ") must be a positive multiple of h_qk (", + params.h_qk, + ")."); + params.heads_per_group = params.h_v / params.h_qk; params.d = q.size(3); params.chunk_size = chunk_size; params.scale = scale; @@ -56,13 +82,15 @@ ChunkKDAFwdIntra( params.chunk_indices_ptr = chunk_indices.data_ptr(); params.Aqk_out_ptr = Aqk_out.data_ptr(); params.Akk_out_ptr = Akk_out.data_ptr(); - params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h); - params.stride_Akk = cute::make_stride(params.chunk_size * params.h, cute::_1{}, params.chunk_size); + // Akk is laid out per v-head: (total_len, chunk_size, h_v). + params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h_v); + params.stride_Akk = cute::make_stride(params.chunk_size * params.h_v, cute::_1{}, params.chunk_size); int tile_num = chunk_indices.size(0); auto device_prop = at::cuda::getCurrentDeviceProperties(); params.num_sm = device_prop->multiProcessorCount; - params.tile_scheduler_params = - StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, (int*)tile_counter.data_ptr()}; + // Tiles are enumerated in v-head space. + params.tile_scheduler_params = StaticPersistentTileScheduler::Params{ + tile_num, params.h_v, params.heads_per_group, params.num_sm, (int*)tile_counter.data_ptr()}; kda::sm100::run_kda_fwd_intra_sm100(params, at::cuda::getCurrentCUDAStream()); } @@ -85,7 +113,31 @@ ChunkKDAFwdRecompWU( KDA_fwd_recomp_w_u_params params; params.total_len = k.size(0) * k.size(1); params.b = cu_seqlens.size(0) - 1; - params.h = k.size(2); + // GVA: K (and optional Q) live in h_qk space; V/G/beta/A/w/u/kg/qg live in h_v space. 
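// Hedged worked example of the grouping contract the checks below enforce
// (numbers are illustrative, not taken from this patch): each qk-head serves
// heads_per_group consecutive v-heads, and the companion qk-head of a v-head
// is v_head_idx / heads_per_group.
//   h_qk = 2, h_v = 6  ->  heads_per_group = 3
//   v-heads {0, 1, 2} -> qk-head 0,  v-heads {3, 4, 5} -> qk-head 1
static_assert(6 % 2 == 0 && 6 / 2 == 3, "h_v must be an exact multiple of h_qk");
static_assert(2 / 3 == 0 && 5 / 3 == 1, "v_head_idx / heads_per_group pairing");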
+ params.h_qk = k.size(2); + params.h_v = v.size(2); + TORCH_CHECK( + g.size(2) == params.h_v, + "ChunkKDAFwdRecompWU: g.size(2) (", + g.size(2), + ") must equal v.size(2) (", + params.h_v, + ")."); + TORCH_CHECK( + beta.size(-1) == params.h_v, + "ChunkKDAFwdRecompWU: beta.size(-1) (", + beta.size(-1), + ") must equal h_v (", + params.h_v, + ")."); + TORCH_CHECK( + params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0, + "ChunkKDAFwdRecompWU: h_v (", + params.h_v, + ") must be a positive multiple of h_qk (", + params.h_qk, + ")."); + params.heads_per_group = params.h_v / params.h_qk; params.d = k.size(3); params.chunk_size = chunk_size; TORCH_CHECK( @@ -108,14 +160,32 @@ ChunkKDAFwdRecompWU( TORCH_CHECK( has_q == has_qg_out, "ChunkKDAFwdRecompWU: q and qg_out must either both be provided or both be omitted."); params.store_qg = has_q && has_qg_out; + if (params.store_qg) { + TORCH_CHECK( + q->size(2) == params.h_qk, + "ChunkKDAFwdRecompWU: q.size(2) (", + q->size(2), + ") must equal h_qk (", + params.h_qk, + ")."); + TORCH_CHECK( + qg_out->size(2) == params.h_v, + "ChunkKDAFwdRecompWU: qg_out.size(2) (", + qg_out->size(2), + ") must equal h_v (", + params.h_v, + ")."); + } params.q_ptr = params.store_qg ? q->data_ptr() : nullptr; params.qg_out_ptr = params.store_qg ? qg_out->data_ptr() : nullptr; - params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h); - params.stride_wukg = cute::make_stride(params.d * params.h, cute::_1{}, params.d); + // w/u/kg/qg are per v-head: (total_len, d, h_v). + params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h_v); + params.stride_wukg = cute::make_stride(params.d * params.h_v, cute::_1{}, params.d); int tile_num = chunk_indices.size(0); auto device_prop = at::cuda::getCurrentDeviceProperties(); params.num_sm = device_prop->multiProcessorCount; - params.tile_scheduler_params = StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, nullptr}; + params.tile_scheduler_params = StaticPersistentTileScheduler::Params{ + tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr}; kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream()); } \ No newline at end of file diff --git a/csrc/kda/sm100/kda_config.hpp b/csrc/kda/sm100/kda_config.hpp index 6f96529..67b496a 100644 --- a/csrc/kda/sm100/kda_config.hpp +++ b/csrc/kda/sm100/kda_config.hpp @@ -17,12 +17,18 @@ #include "kda/sm100/tile_scheduler.hpp" struct KDA_fwd_intra_params { - using GmemShapeAkk = cute::Shape; // (seqlen_kv, seqlen_kv, h) + // Akk shape is (total_seqlen, chunk_size, num_v_heads). Under GVA (num_v_heads > num_qk_heads), + // Aqk and Akk are produced per v-head because g/beta/Akk scaling all live in v-head space. + using GmemShapeAkk = cute::Shape; // (seqlen_kv, chunk_size, h_v) using GmemStrideAkk = cute::Stride; int total_q_len; int b; - int h; + // GVA: Q/K are sized by num_qk_heads; V, g, beta are sized by num_v_heads; Aqk/Akk are per v-head. + // When num_v_heads == num_qk_heads, heads_per_group == 1 and behaviour matches the pre-GVA path. 
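// Hedged layout sketch for the pointer comments below (sizes are illustrative):
// with the batch folded into t, the flat element offset for token t, head h,
// column c is
//   q/k (t, h, c): t * h_qk * d + h * d + c
//   g   (t, h, c): t * h_v  * d + h * d + c
// e.g. d = 128, h_qk = 2, h_v = 4, token 3, head 1, column 0:
static_assert(3 * 2 * 128 + 1 * 128 == 896, "q/k example offset (h_qk space)");
static_assert(3 * 4 * 128 + 1 * 128 == 1664, "g example offset (h_v space)");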
+ int h_qk; + int h_v; + int heads_per_group; // = h_v / h_qk, precomputed on host int d; int chunk_size; float scale; @@ -30,12 +36,12 @@ struct KDA_fwd_intra_params { bool unified_gref; bool is_beta_bf16; - void* __restrict__ q_ptr; //[b, t, h, d] - void* __restrict__ k_ptr; //[b, t, h, d] - void* __restrict__ g_ptr; //[b, t, h, d] - void* __restrict__ beta_ptr; //[b, t, h] - void* __restrict__ Aqk_out_ptr; //[b, t, h, BT] - void* __restrict__ Akk_out_ptr; //[b, t, h, BT] + void* __restrict__ q_ptr; //[b, t, h_qk, d] + void* __restrict__ k_ptr; //[b, t, h_qk, d] + void* __restrict__ g_ptr; //[b, t, h_v, d] + void* __restrict__ beta_ptr; //[b, t, h_v] + void* __restrict__ Aqk_out_ptr; //[b, t, h_v, BT] + void* __restrict__ Akk_out_ptr; //[b, t, h_v, BT] void* __restrict__ cu_seqlens_ptr; //[b + 1] void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2] @@ -48,28 +54,32 @@ struct KDA_fwd_intra_params { }; struct KDA_fwd_recomp_w_u_params { - using GmemShapeWUKg = cute::Shape; // (seqlen_kv, seqlen_kv, h) + // w/u/kg/qg all have shape (total_seqlen, d, num_v_heads) under GVA. + using GmemShapeWUKg = cute::Shape; // (seqlen_kv, d, h_v) using GmemStrideWUKg = cute::Stride; int total_len; int b; - int h; + // GVA: K and (optional) Q are sized by num_qk_heads; V/G/beta/Akk/w/u/kg/qg are per v-head. + int h_qk; + int h_v; + int heads_per_group; // = h_v / h_qk, precomputed on host int d; int chunk_size; bool is_beta_bf16; - void* __restrict__ k_ptr; //[b, t, h, d] - void* __restrict__ v_ptr; //[b, t, h, d] - void* __restrict__ q_ptr; //[b, t, h, d] (optional, for StoreQG) - void* __restrict__ beta_ptr; //[b, t, h] - void* __restrict__ A_ptr; //[b. t, h, BT] - void* __restrict__ g_ptr; //[b, t, h, d] + void* __restrict__ k_ptr; //[b, t, h_qk, d] + void* __restrict__ v_ptr; //[b, t, h_v, d] + void* __restrict__ q_ptr; //[b, t, h_qk, d] (optional, for StoreQG) + void* __restrict__ beta_ptr; //[b, t, h_v] + void* __restrict__ A_ptr; //[b, t, h_v, BT] + void* __restrict__ g_ptr; //[b, t, h_v, d] void* __restrict__ cu_seqlens_ptr; //[b + 1] void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2] - void* __restrict__ w_out_ptr; //[b, t, h, d] - void* __restrict__ u_out_ptr; //[b, t, h, d] - void* __restrict__ kg_out_ptr; //[b, t, h, d] - void* __restrict__ qg_out_ptr; //[b, t, h, d] (optional, for StoreQG) + void* __restrict__ w_out_ptr; //[b, t, h_v, d] + void* __restrict__ u_out_ptr; //[b, t, h_v, d] + void* __restrict__ kg_out_ptr; //[b, t, h_v, d] + void* __restrict__ qg_out_ptr; //[b, t, h_v, d] (optional, for StoreQG) bool store_qg; diff --git a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp index 60dc4b3..021bec8 100644 --- a/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp @@ -53,8 +53,8 @@ struct KdaChunkFwdIntraKernelSm100 { using SmemLayoutInputFP32 = typename Mainloop::SmemLayoutInputFP32; // TMA params (for host launcher) - template - using TmaParams = typename Mainloop::template TmaParams; + template + using TmaParams = typename Mainloop::template TmaParams; // Pipeline types (for construction in operator()) using PipelineQKG = typename Mainloop::PipelineQKG; @@ -321,29 +321,40 @@ __launch_bounds__(512, 1, 1) kda_fwd_intra_sm100_kernel_entry( template inline void run_kda_fwd_intra_sm100_impl_dispatch(KDA_fwd_intra_params& params, cudaStream_t stream) { - auto shape_QKG = make_shape(params.total_q_len, params.d, params.h); - auto stride_QKG = make_stride(params.h * params.d, 
_1{}, params.d); + // GVA: Q/K are sized by `h_qk`; G is sized by `h_v`. When HV == HQK + // (heads_per_group == 1), shape_QK and shape_VG coincide with the + // pre-GVA shape_QKG and behaviour is unchanged. + auto shape_QK = make_shape(params.total_q_len, params.d, params.h_qk); + auto stride_QK = make_stride(params.h_qk * params.d, _1{}, params.d); + auto shape_VG = make_shape(params.total_q_len, params.d, params.h_v); + auto stride_VG = make_stride(params.h_v * params.d, _1{}, params.d); // --- Build TMA descriptors --- auto tma_Q = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_K = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_G = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_QKG, stride_QKG)), + make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputFP32{}); // --- Pack TMA params --- - typename Kernel::template TmaParams + typename Kernel::template TmaParams< + decltype(shape_QK), + decltype(shape_VG), + decltype(tma_Q), + decltype(tma_K), + decltype(tma_G)> tma_params = { - shape_QKG, + shape_QK, + shape_VG, tma_Q, tma_K, tma_G, diff --git a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp index 3aa2746..68f6baa 100644 --- a/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_intra_mainloop_sm100.hpp @@ -227,9 +227,13 @@ struct KdaChunkFwdIntraMainloopSm100 { }; // ===================== TMA Params ===================== - template + // GVA: Q/K live in h_qk head space (shape_qk), while G lives in h_v + // head space (shape_vg). When h_v == h_qk both shapes coincide and the + // TMA descriptors degrade to the pre-GVA behaviour. + template struct TmaParams { - ShapeQKG shape_qkg; + ShapeQK shape_qk; + ShapeVG shape_vg; TMA_Q tma_q; TMA_K tma_k; TMA_G tma_g; @@ -318,7 +322,10 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // head_idx here is the v-head index (Aqk/Akk/beta/g live in v-head space). + // qk_head_idx is only consumed by the TMA load warp for Q/K slicing. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -502,7 +509,8 @@ struct KdaChunkFwdIntraMainloopSm100 { int token_offset = cu_seqlens_ptr[batch_idx]; int row = idx_in_warpgroup % 64; int BT = TileT; - int H = params.h; + // Aqk is laid out per v-head: row-stride is h_v * BT, head slot offset is head_idx * BT. 
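// Hedged address sketch for the store below (sizes illustrative): for global row
// r = token_offset + tile_idx * TileT + row and v-head h, the element offset into
// Aqk is  r * (h_v * BT) + h * BT + col.
//   h_v = 4, BT = 64, r = 10, h = 3, col = 0  ->  10 * 256 + 3 * 64 = 2752
static_assert(10 * (4 * 64) + 3 * 64 == 2752, "Aqk row / head-slot offset example");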
+ int H = params.h_v; __nv_bfloat16* Aqk_base = reinterpret_cast<__nv_bfloat16*>(params.Aqk_out_ptr); __nv_bfloat16* qk_out_row = Aqk_base + static_cast(token_offset + tile_idx * TileT + row) * H * BT + head_idx * BT; @@ -568,7 +576,10 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // MMA loop does not actually consume head_idx, but we decode to advance the + // same tile space as the other warps (num_blocks * num_v_heads). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -703,21 +714,24 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - // Decode tile coordinates - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Decode tile coordinates. head_idx is the v-head index (used for G), + // and qk_head_idx is the companion Q/K head (computed from heads_per_group). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); - int head_idx = get<1>(blk_coord); + int head_idx = get<1>(blk_coord); // v-head index int tile_idx = get<2>(blk_coord); + int qk_head_idx = get<3>(blk_coord); // == head_idx / heads_per_group int token_offset = cu_seqlens_ptr[batch_idx]; int seq_len = cu_seqlens_ptr[batch_idx + 1] - cu_seqlens_ptr[batch_idx]; int sub_seq_len = min(TileT, seq_len - tile_idx * TileT); Tensor mQ = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_q.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_q.get_tma_tensor(tma_params.shape_qk)); Tensor mK = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qk)); Tensor mG = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_qkg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_vg)); // TMA load body (Q, K, G — unified pipeline, single barrier per stage) CUTE_NO_UNROLL @@ -727,12 +741,13 @@ struct KdaChunkFwdIntraMainloopSm100 { Tensor sK = make_tensor(make_smem_ptr(shared_plan->k[buf_idx].data()), SmemLayoutInputBF16{}); Tensor sG = make_tensor(make_smem_ptr(shared_plan->g[buf_idx].data()), SmemLayoutInputFP32{}); + // GVA: K and Q are sliced by qk_head_idx; G is sliced by head_idx (v-head). 
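// Hedged consequence of the per-v-head tile enumeration (numbers illustrative):
// all v-heads of a group decode to the same qk-head, so within one chunk the same
// K/Q head slice is fetched once per v-head in its group, while each G slice is
// fetched exactly once.
//   heads_per_group = 2: v-heads 0,1 -> qk-head 0;  v-heads 2,3 -> qk-head 1
static_assert(0 / 2 == 0 && 1 / 2 == 0 && 2 / 2 == 1 && 3 / 2 == 1,
              "v-heads of a group share one K/Q slice");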
Tensor gK = local_tile( - mK(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); + mK(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); Tensor gG = local_tile( mG(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); Tensor gQ = local_tile( - mQ(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); + mQ(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, k_idx)); // Single acquire for all three TMA copies qkg_load_pipeline.producer_acquire(qkg_load_pipe_state_write); @@ -769,7 +784,9 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Akk is laid out per v-head (params.shape_Akk uses h_v), so we index by head_idx. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -882,7 +899,9 @@ struct KdaChunkFwdIntraMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // beta is per v-head: layout (total_seqlen, h_v), row stride = h_v. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -896,7 +915,7 @@ struct KdaChunkFwdIntraMainloopSm100 { shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = (thread_idx < sub_seq_len) ? float(reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h_v + head_idx]) : float(0); } fence_view_async_shared(); diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp index 6bfa78c..a0a0318 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_kernel_sm100.hpp @@ -41,14 +41,16 @@ struct KdaChunkFwdRecompWUKernelSm100 { // TMA params (for host launcher) template < - typename ShapeKVG, + typename ShapeQK, + typename ShapeVG, typename ShapeAkk, typename TMA_V, typename TMA_K, typename TMA_G, typename TMA_Akk, typename TMA_Q = int> - using TmaParams = typename Mainloop::template TmaParams; + using TmaParams = + typename Mainloop::template TmaParams; // Pipeline types (for construction in operator()) using PipelineA = typename Mainloop::PipelineA; @@ -429,25 +431,29 @@ __launch_bounds__(384, 1, 1) kda_fwd_recomp_w_u_sm100_kernel_entry( template inline void run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cudaStream_t stream) { - auto shape_KVG = make_shape(params.total_len, params.d, params.h); - auto stride_KVG = make_stride(params.h * params.d, _1{}, params.d); - auto shape_Akk = make_shape(params.total_len, params.chunk_size, params.h); - auto stride_Akk = make_stride(params.h * params.chunk_size, _1{}, params.chunk_size); + // GVA: K and (optional) Q are sized by h_qk; V and G are sized by h_v. + // Akk lives in v-head space (BT x BT per v-head). 
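// Hedged stride sketch for the layouts built below (sizes illustrative): with
// d = 128, chunk_size = 64, h_qk = 2, h_v = 4, the (row, col, head) strides are
//   QK  : (h_qk * d,          1, d)          = (256, 1, 128)
//   VG  : (h_v  * d,          1, d)          = (512, 1, 128)
//   Akk : (h_v  * chunk_size, 1, chunk_size) = (256, 1,  64)
static_assert(2 * 128 == 256 && 4 * 128 == 512 && 4 * 64 == 256, "example row strides");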
+ auto shape_QK = make_shape(params.total_len, params.d, params.h_qk); + auto stride_QK = make_stride(params.h_qk * params.d, _1{}, params.d); + auto shape_VG = make_shape(params.total_len, params.d, params.h_v); + auto stride_VG = make_stride(params.h_v * params.d, _1{}, params.d); + auto shape_Akk = make_shape(params.total_len, params.chunk_size, params.h_v); + auto stride_Akk = make_stride(params.h_v * params.chunk_size, _1{}, params.chunk_size); // --- Build TMA descriptors --- auto tma_V = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.v_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.v_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputBF16{}); auto tma_K = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.k_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.k_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); auto tma_G = cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_VG, stride_VG)), typename Kernel::SmemLayoutInputFP32{}); auto tma_Akk = cute::make_tma_copy( @@ -455,12 +461,12 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu make_tensor(make_gmem_ptr((bf16*)params.A_ptr), make_layout(shape_Akk, stride_Akk)), typename Kernel::SmemLayoutInputAkkBF16{}); - // Q TMA descriptor (only meaningful when StoreQG=true) + // Q TMA descriptor (only meaningful when StoreQG=true). Q lives in h_qk head space. auto tma_Q = [&]() { if constexpr (Kernel::StoreQG) { return cute::make_tma_copy( SM90_TMA_LOAD{}, - make_tensor(make_gmem_ptr((bf16*)params.q_ptr), make_layout(shape_KVG, stride_KVG)), + make_tensor(make_gmem_ptr((bf16*)params.q_ptr), make_layout(shape_QK, stride_QK)), typename Kernel::SmemLayoutInputBF16{}); } else { return 0; // placeholder, not used @@ -469,14 +475,15 @@ run_kda_fwd_recomp_w_u_sm100_impl_dispatch(KDA_fwd_recomp_w_u_params& params, cu // --- Pack TMA params --- typename Kernel::template TmaParams< - decltype(shape_KVG), + decltype(shape_QK), + decltype(shape_VG), decltype(shape_Akk), decltype(tma_V), decltype(tma_K), decltype(tma_G), decltype(tma_Akk), decltype(tma_Q)> - tma_params = {shape_KVG, shape_Akk, tma_V, tma_K, tma_G, tma_Akk, tma_Q}; + tma_params = {shape_QK, shape_VG, shape_Akk, tma_V, tma_K, tma_G, tma_Akk, tma_Q}; // --- Launch config --- auto kernel_fn = &kda_fwd_recomp_w_u_sm100_kernel_entry; diff --git a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp index 718e075..07bcf26 100644 --- a/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp +++ b/csrc/kda/sm100/kda_fwd_recomp_w_u_mainloop_sm100.hpp @@ -187,8 +187,11 @@ struct KdaChunkFwdRecompWUMainloopSm100 { }; // ===================== TMA Params ===================== + // GVA: K and (optional) Q live in h_qk head space (shape_qk), while V + // and G live in h_v head space (shape_vg). Akk is per v-head. 
template < - typename ShapeKVG, + typename ShapeQK, + typename ShapeVG, typename ShapeAkk, typename TMA_V, typename TMA_K, @@ -196,7 +199,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 { typename TMA_Akk, typename TMA_Q = int> struct TmaParams { - ShapeKVG shape_kvg; + ShapeQK shape_qk; + ShapeVG shape_vg; ShapeAkk shape_Akk; TMA_V tma_v; TMA_K tma_k; @@ -252,7 +256,10 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Prologue touches K (h_qk) and G (h_v) + beta (h_v) + optional Q (h_qk). + // head_idx is the v-head index; qk_head_idx is derived via heads_per_group. + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -632,7 +639,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Epilogue consumes V/beta (both h_v) and writes w/u/kg/qg (all h_v). + auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -732,9 +741,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { // each thread processes one row of W/U (TileK columns) int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16); - // GMEM output address: layout [total_len, d, h], stride [d*h, 1, d] + // GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d] __nv_bfloat16* out_row_base = - out_ptr_base + (token_offset_cur + row) * params.d * params.h + head_idx * params.d; + out_ptr_base + (token_offset_cur + row) * params.d * params.h_v + head_idx * params.d; constexpr int QuarK = TileK / 4; @@ -796,7 +805,8 @@ struct KdaChunkFwdRecompWUMainloopSm100 { CUTE_NO_UNROLL for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { // int tid = tile_scheduler.get_current_tile_id(); - // auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h_v, params.heads_per_group, + // chunk_indices_ptr, cu_seqlens_ptr); // ============================================================ // Once per WU: Wait for Akk in SMEM (from Load warp) @@ -876,31 +886,36 @@ struct KdaChunkFwdRecompWUMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - // Decode tile coordinates - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // Decode tile coordinates. head_idx is the v-head (used for V/G/Akk + // TMA loads); qk_head_idx (= head_idx / heads_per_group) is used for + // K/Q TMA loads under GVA. 
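// Hedged per-tile walk-through (numbers illustrative): with heads_per_group = 2,
// the tile that decodes to head_idx = 5 loads V, G and Akk from v-head 5 and
// K (plus Q when StoreQG) from qk-head 5 / 2 = 2.
static_assert(5 / 2 == 2, "qk_head_idx for v-head 5 with heads_per_group = 2");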
+ auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); - int head_idx = get<1>(blk_coord); + int head_idx = get<1>(blk_coord); // v-head int tile_idx = get<2>(blk_coord); + int qk_head_idx = get<3>(blk_coord); // qk-head int token_offset = cu_seqlens_ptr[batch_idx]; int seq_len = cu_seqlens_ptr[batch_idx + 1] - cu_seqlens_ptr[batch_idx]; int sub_seq_len = min(TileT, seq_len - tile_idx * TileT); // Build GMEM tensor views (with domain offset for batch) + // K and Q live in h_qk head space (shape_qk); V, G and Akk live in h_v space. Tensor mK = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_k.get_tma_tensor(tma_params.shape_qk)); Tensor mV = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_v.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_v.get_tma_tensor(tma_params.shape_vg)); Tensor mG = domain_offset( - make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_kvg)); + make_coord(token_offset, _0{}, _0{}), tma_params.tma_g.get_tma_tensor(tma_params.shape_vg)); Tensor mA = domain_offset( make_coord(token_offset, _0{}, _0{}), tma_params.tma_akk.get_tma_tensor(tma_params.shape_Akk)); - // Q GMEM tensor (only used when StoreQG=true) + // Q GMEM tensor (only used when StoreQG=true). Q lives in h_qk space. [[maybe_unused]] auto mQ = [&]() { if constexpr (StoreQG) { return domain_offset( make_coord(token_offset, _0{}, _0{}), - tma_params.tma_q.get_tma_tensor(tma_params.shape_kvg)); + tma_params.tma_q.get_tma_tensor(tma_params.shape_qk)); } else { return 0; // unused placeholder } @@ -933,8 +948,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { Tensor sG = make_tensor( make_smem_ptr(shared_plan->g[g_pipe_state_write.index()].data()), SmemLayoutInputFP32{}); + // GVA slicing: K uses qk_head_idx; V and G use the v-head index. Tensor gK = local_tile( - mK(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); + mK(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); Tensor gV = local_tile( mV(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); Tensor gG = local_tile( @@ -960,8 +976,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { Tensor sQ = make_tensor( make_smem_ptr(shared_plan->q_buf.q[q_pipe_state_write.index()].data()), SmemLayoutInputBF16{}); + // Q (StoreQG) lives in h_qk space → slice by qk_head_idx. Tensor gQ = local_tile( - mQ(_, _, head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); + mQ(_, _, qk_head_idx), make_shape(Int{}, Int{}), make_coord(tile_idx, i_k)); q_pipeline.producer_acquire(q_pipe_state_write); ku::launch_tma_copy( tma_params.tma_q, gQ, sQ, *q_pipeline.producer_get_barrier(q_pipe_state_write)); @@ -994,7 +1011,9 @@ struct KdaChunkFwdRecompWUMainloopSm100 { for (; tile_scheduler.is_valid(); tile_scheduler.advance()) { int tid = tile_scheduler.get_current_tile_id(); - auto blk_coord = TileScheduler::decode_tile_coord(tid, params.h, chunk_indices_ptr, cu_seqlens_ptr); + // LoadAux: beta is per v-head (row stride = h_v). 
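// Hedged flat-index example for the beta load below (sizes illustrative): beta is
// stored as [total_len, h_v], so the element for token
// (token_offset + tile_idx * TileT + thread_idx) and v-head head_idx sits at
// token * h_v + head_idx.
//   h_v = 4, token_offset = 128, tile_idx = 1, TileT = 64, thread_idx = 5, head_idx = 3
//   ->  (128 + 64 + 5) * 4 + 3 = 791
static_assert((128 + 1 * 64 + 5) * 4 + 3 == 791, "beta flat-index example");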
+ auto blk_coord = TileScheduler::decode_tile_coord( + tid, params.h_v, params.heads_per_group, chunk_indices_ptr, cu_seqlens_ptr); int batch_idx = get<0>(blk_coord); int head_idx = get<1>(blk_coord); int tile_idx = get<2>(blk_coord); @@ -1010,7 +1029,7 @@ struct KdaChunkFwdRecompWUMainloopSm100 { float beta_val = (thread_idx < sub_seq_len) ? float(reinterpret_cast( - params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h + head_idx]) + params.beta_ptr)[(token_offset + tile_idx * TileT + thread_idx) * params.h_v + head_idx]) : float(0); shared_plan->beta_smem[beta_pipe_state_write.index()][thread_idx] = beta_val; } diff --git a/csrc/kda/sm100/tile_scheduler.hpp b/csrc/kda/sm100/tile_scheduler.hpp index 47044aa..695bb26 100644 --- a/csrc/kda/sm100/tile_scheduler.hpp +++ b/csrc/kda/sm100/tile_scheduler.hpp @@ -26,11 +26,20 @@ // No smem synchronization needed — every CTA processes tiles starting // at blockIdx.x and striding by gridDim.x. All warps within a CTA // independently maintain the same tile_id, so no tile pipeline is needed. +// +// GVA (Grouped V-head Attention) support: +// Q/K are sized by `num_qk_heads`; V, g, beta, O and state tensors are +// sized by `num_v_heads`. We enumerate tiles by `num_v_heads` so that +// each v-head is scheduled independently, and derive the companion +// `qk_head_idx = v_head_idx / heads_per_group` on the device side. +// `heads_per_group = num_v_heads / num_qk_heads` is precomputed on the +// host to avoid a per-tile integer division. // =================================================================== struct StaticPersistentTileScheduler { struct Params { - int num_blocks; // number of sequence chunks (from chunk_indices) - int num_heads; + int num_blocks; // number of sequence chunks (from chunk_indices) + int num_heads; // == num_v_heads; tiles are enumerated by v-head + int heads_per_group; // == num_v_heads / num_qk_heads, precomputed on host int num_sm; int* tile_counter; // unused @@ -77,14 +86,22 @@ struct StaticPersistentTileScheduler { return current_tile_id < total_tiles(); } + // Decode tile_id -> (batch_idx, v_head_idx, seq_idx, qk_head_idx). + // `num_v_heads` is the number of V/O/g/beta heads; tile enumeration is + // done in v-head space. `heads_per_group` (= num_v_heads/num_qk_heads) + // is used to derive the companion Q/K head index for GVA. + // For backward compatibility, when HV == HQK, `heads_per_group == 1` + // and `qk_head_idx == v_head_idx`. 
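// Hedged decode example (numbers illustrative): with num_v_heads = 4 and
// heads_per_group = 2, tile_id = 11 decodes to
//   tile_idx_raw = 11 / 4 = 2   (third entry of chunk_indices)
//   v_head_idx   = 11 % 4 = 3
//   qk_head_idx  =  3 / 2 = 1
// so consecutive tile_ids sweep every v-head of a chunk before the next chunk.
static_assert(11 / 4 == 2 && 11 % 4 == 3 && (11 % 4) / 2 == 1, "tile decode example");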
CUTLASS_DEVICE static auto - decode_tile_coord(int tile_id, int num_heads, int* chunk_indices_ptr, int* cu_seqlens_ptr) { + decode_tile_coord( + int tile_id, int num_v_heads, int heads_per_group, int* chunk_indices_ptr, int* /*cu_seqlens_ptr*/) { using namespace cute; - int tile_idx_raw = tile_id / num_heads; - int head_idx = tile_id % num_heads; + int tile_idx_raw = tile_id / num_v_heads; + int v_head_idx = tile_id % num_v_heads; + int qk_head_idx = v_head_idx / heads_per_group; int batch_idx = chunk_indices_ptr[tile_idx_raw * 2]; int seq_idx = chunk_indices_ptr[tile_idx_raw * 2 + 1]; - return make_coord(batch_idx, head_idx, seq_idx, 0); + return make_coord(batch_idx, v_head_idx, seq_idx, qk_head_idx); } }; \ No newline at end of file diff --git a/cula/kda/chunk_intra.py b/cula/kda/chunk_intra.py index 0703638..9fcc93a 100644 --- a/cula/kda/chunk_intra.py +++ b/cula/kda/chunk_intra.py @@ -759,7 +759,22 @@ def chunk_kda_fwd_intra( unified_gref: bool = False, # Set True for ~5% extra perf (slightly lower precision) ): assert safe_gate, "Only safe_gate=True is supported in chunk_kda_fwd_intra for now" - B, T, H, K = k.shape + # GVA support: Q/K have head-dim HQK; V/g/beta/Aqk/Akk/w/u/kg/qg have head-dim HV. + # Pre-GVA behaviour is preserved when HV == HQK. + B, T, HQK, K = k.shape + HV = v.shape[2] + assert v.shape[0] == B and v.shape[1] == T, ( + f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}" + ) + assert HV > 0 and HQK > 0 and HV % HQK == 0, ( + f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})" + ) + if gk is not None: + assert gk.shape[0] == B and gk.shape[1] == T and gk.shape[2] == HV, ( + f"gk shape must be (B, T, HV={HV}, K); got {tuple(gk.shape)}" + ) + if beta is not None: + assert beta.shape[-1] == HV, f"beta last dim must equal HV={HV}; got {tuple(beta.shape)}" BT = chunk_size if cu_seqlens is None: @@ -773,18 +788,20 @@ def chunk_kda_fwd_intra( "cu_seqlens and chunk_indices must be int32 for cuda impl" ) - Aqk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) - Akk = torch.empty(B, T, H, BT, device=k.device, dtype=k.dtype) + # Aqk/Akk are produced per v-head (they live in v-head space because g/beta are per v-head). + Aqk = torch.empty(B, T, HV, BT, device=k.device, dtype=k.dtype) + Akk = torch.empty(B, T, HV, BT, device=k.device, dtype=k.dtype) tile_counter = torch.zeros(1, dtype=torch.int32, device=q.device) cula_cuda.chunk_kda_fwd_intra_cuda( q, k, gk, beta, cu_seqlens, chunk_indices, Aqk, Akk, tile_counter, scale, chunk_size, use_tf32_inverse, unified_gref ) - w = torch.empty_like(k) + # w/u/kg/qg are all per-v-head outputs. + w = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) u = torch.empty_like(v) - qg = torch.empty_like(q) if disable_recompute else None - kg = torch.empty_like(k) if gk is not None else None + qg = torch.empty(B, T, HV, K, device=q.device, dtype=q.dtype) if disable_recompute else None + kg = torch.empty(B, T, HV, K, device=k.device, dtype=k.dtype) if gk is not None else None cula_cuda.recompute_w_u_cuda( k, v, beta, Akk, gk, cu_seqlens, chunk_indices, w, u, kg, chunk_size, q if disable_recompute else None, qg diff --git a/tests/test_kda_gva_intra_sm100.py b/tests/test_kda_gva_intra_sm100.py new file mode 100644 index 0000000..a86e56f --- /dev/null +++ b/tests/test_kda_gva_intra_sm100.py @@ -0,0 +1,375 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SM100 KDA GVA (HV > HQK) support in chunk_kda_fwd_intra. + +The SM100 kernels (kda_fwd_intra / kda_fwd_recomp_w_u) now accept: + * q, k with head-dim ``HQK`` + * v, g, beta with head-dim ``HV`` where ``HV = group_size * HQK`` (group_size >= 1) + +This file verifies that the cuLA GVA path produces numerically matching results +compared to the FLA Triton reference, where the FLA reference does not natively +support GVA and therefore receives ``k`` replicated along the head axis to +``HV`` heads. Both uniform-length and varlen layouts are covered, and an +additional degeneracy test asserts that ``HV == HQK`` (group_size == 1) keeps +the non-GVA behaviour untouched. +""" + +from __future__ import annotations + +import pytest +import torch +from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra as fla_chunk_kda_fwd_intra +from fla.ops.kda.gate import kda_gate_chunk_cumsum +from fla.ops.utils.constant import RCP_LN2 +from fla.ops.utils.index import prepare_chunk_indices +from fla.utils import assert_close, device + +from cula.kda.chunk_intra import chunk_kda_fwd_intra as cula_chunk_kda_fwd_intra +from cula.utils import prepare_uniform_cu_seqlens + +pytestmark = pytest.mark.sm100_only + + +# ========================================================================= +# Helpers +# ========================================================================= + +def _l2norm_last(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.normalize(x.float(), p=2, dim=-1).to(x.dtype) + + +def _repeat_head(x: torch.Tensor, group_size: int, head_dim: int = 2) -> torch.Tensor: + """Replicate ``x`` along the head axis by ``group_size``. + + Mirrors GVA's broadcasting semantics: each QK head is paired with + ``group_size`` consecutive V heads, so ``k[..., h_qk, :]`` is used by + ``v[..., h_qk * group_size : (h_qk + 1) * group_size, :]``. + """ + return x.repeat_interleave(group_size, dim=head_dim).contiguous() + + +def _make_gva_inputs( + B: int, + T: int, + HQK: int, + HV: int, + D: int, + chunk_size: int, + cu_seqlens: torch.Tensor | None = None, + dtype: torch.dtype = torch.bfloat16, + seed: int = 42, +): + """Construct inputs for chunk_kda_fwd_intra in GVA layout. + + Returns: + q, k : (B, T, HQK, D) dtype + v : (B, T, HV, D) dtype + g : (B, T, HV, D) float32, after kda_gate_chunk_cumsum + beta : (B, T, HV) float32 in (0, 1) + scale : float + cu_seqlens : (N+1,) int32 or None + chunk_indices: (NT, 2) int32 or None + """ + assert HV % HQK == 0 and HV >= HQK, f"invalid HV/HQK: {HV}/{HQK}" + + torch.manual_seed(seed) + scale = D ** (-0.5) + + # QK are in HQK head space; V / gates / beta live in HV space. + q = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + k = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + v = torch.randn(B, T, HV, D, dtype=dtype, device=device) + g_raw = torch.randn(B, T, HV, D, dtype=dtype, device=device) + beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() + + # l2-normalise q/k so that scale/gate ranges match production use. 
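# Minimal sketch of what _l2norm_last guarantees (kept as a comment so the seeded
# RNG stream above stays untouched; shapes are illustrative):
#   x = torch.randn(2, 3, 8)
#   xn = torch.nn.functional.normalize(x.float(), p=2, dim=-1)
#   assert torch.allclose(xn.norm(dim=-1), torch.ones(2, 3), atol=1e-5)
# i.e. every trailing-dim row has unit l2 norm, bounding q @ k^T before scale/gates.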
+ q = _l2norm_last(q) + k = _l2norm_last(k) + + # Per-HV gate preprocessing (cumsum inside chunks). + A_log = torch.randn(HV, dtype=torch.float, device=device) + dt_bias = torch.randn(HV * D, dtype=torch.float, device=device) + + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + ) + g = kda_gate_chunk_cumsum( + g=g_raw, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=-5.0, + ) + return q, k, v, g, beta, scale, cu_seqlens, chunk_indices + + +def _run_fla_ref(q, k_hqk, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, group_size, disable_recompute): + """Reference: replicate k along head axis to HV, then call FLA intra. + + FLA's chunk_kda_fwd_intra assumes H == HQK == HV (no GVA), so we construct + the HV-head view of k and q before invoking it. + """ + k_hv = _repeat_head(k_hqk, group_size) + q_hv = _repeat_head(q, group_size) + return fla_chunk_kda_fwd_intra( + q=q_hv, + k=k_hv, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=True, + disable_recompute=disable_recompute, + ) + + +def _run_cula_gva(q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute): + return cula_chunk_kda_fwd_intra( + q=q, + k=k, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=True, + disable_recompute=disable_recompute, + ) + + +# ========================================================================= +# Uniform-length tests +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("B", "T", "HQK", "group_size", "D"), + [ + pytest.param(*cfg, id="B{}-T{}-HQK{}-gs{}-D{}".format(*cfg)) + for cfg in [ + # group_size == 2: classic GVA 2:1 + (1, 256, 2, 2, 128), + (2, 512, 4, 2, 128), + # group_size == 4: wider grouping + (1, 1024, 2, 4, 128), + (2, 1024, 4, 4, 128), + # Non-multiple-of-BT sequence length to stress boundary handling. + (1, 500, 2, 2, 128), + (1, 1000, 4, 2, 128), + ] + ], +) +def test_gva_intra_uniform(B, T, HQK, group_size, D, disable_recompute): + """cuLA GVA path must match FLA(k-replicated-to-HV) for uniform seqlens.""" + HV = HQK * group_size + chunk_size = 64 + + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + + # cuLA GVA path (k in HQK head space). + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute, + ) + + # FLA reference (k replicated to HV). + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = _run_fla_ref( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, group_size, disable_recompute, + ) + + # All outputs live in HV head space → shapes must match directly. 
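# Concrete instance (illustrative cross-check): the B=2, T=512, HQK=4,
# group_size=2, D=128, chunk_size=64 config gives HV = 8, so Aqk/Akk come out as
# (2, 512, 8, 64) and w/u/kg as (2, 512, 8, 128). Generic form of that check:
assert Aqk_c.shape == (B, T, HV, chunk_size), Aqk_c.shape
assert w_c.shape == (B, T, HV, D), w_c.shape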
+ assert Aqk_c.shape == Aqk_r.shape, (Aqk_c.shape, Aqk_r.shape) + assert Akk_c.shape == Akk_r.shape, (Akk_c.shape, Akk_r.shape) + assert w_c.shape == w_r.shape, (w_c.shape, w_r.shape) + assert u_c.shape == u_r.shape, (u_c.shape, u_r.shape) + assert kg_c.shape == kg_r.shape, (kg_c.shape, kg_r.shape) + + # Aqk / Akk are the core A-matrices; they drive w/u, so keep tolerances tight. + assert_close("Aqk", Aqk_r, Aqk_c, 0.005) + assert_close("Akk", Akk_r, Akk_c, 0.008) + + # recompute_w_u outputs + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + + if disable_recompute: + assert qg_c is not None and qg_r is not None + assert qg_c.shape == qg_r.shape, (qg_c.shape, qg_r.shape) + assert_close("qg", qg_r, qg_c, 0.005) + else: + assert qg_c is None, "cuLA must not materialise qg when disable_recompute=False" + + +# ========================================================================= +# Varlen tests +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("HQK", "group_size", "D", "cu_seqlens"), + [ + pytest.param(*cfg, id="HQK{}-gs{}-D{}-ns{}".format(cfg[0], cfg[1], cfg[2], len(cfg[3]) - 1)) + for cfg in [ + (2, 2, 128, [0, 256, 500, 1000]), + (4, 2, 128, [0, 100, 300, 1200, 2000]), + (2, 4, 128, [0, 15, 100, 300, 1200, 2048]), + # Simulated realistic trace. + ( + 4, 2, 128, + [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096], + ), + ] + ], +) +def test_gva_intra_varlen(HQK, group_size, D, cu_seqlens, disable_recompute): + """GVA correctness under variable-length (packed) inputs.""" + HV = HQK * group_size + chunk_size = 64 + + cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + T = int(cu_seqlens_t[-1].item()) + # Packed layout uses B=1 and a flat time axis. + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices = _make_gva_inputs( + B=1, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens_t, + ) + + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices, chunk_size, disable_recompute, + ) + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = _run_fla_ref( + q, k, v, g, beta, scale, cu_seqlens_t, chunk_indices, chunk_size, group_size, disable_recompute, + ) + + assert_close("Aqk", Aqk_r, Aqk_c, 0.005) + assert_close("Akk", Akk_r, Akk_c, 0.008) + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + + if disable_recompute: + assert_close("qg", qg_r, qg_c, 0.005) + else: + assert qg_c is None + + +# ========================================================================= +# Degeneracy: HV == HQK must match the non-GVA (same-shape) reference +# ========================================================================= + +@pytest.mark.parametrize("disable_recompute", [False, True], ids=["recomp", "no_recomp"]) +@pytest.mark.parametrize( + ("B", "T", "H", "D"), + [ + pytest.param(*cfg, id="B{}-T{}-H{}-D{}".format(*cfg)) + for cfg in [ + (1, 512, 4, 128), + (2, 1024, 4, 128), + ] + ], +) +def test_gva_intra_degenerate_equals_non_gva(B, T, H, D, disable_recompute): + """When HV == HQK, the GVA code path must be byte-for-byte equivalent + to the non-GVA path that existed before this change. 
+ + We do not have a separate "non-GVA" entrypoint, but we can assert the + cuLA path matches FLA with *no* head replication (group_size=1), which + exercises the ``HV == HQK`` fast-path inside the new kernels. + """ + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=H, HV=H, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + + w_c, u_c, qg_c, kg_c, Aqk_c, Akk_c = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute, + ) + # group_size=1 → no replication; identical input shape to cuLA. + w_r, u_r, qg_r, kg_r, Aqk_r, Akk_r = fla_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=scale, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices, + safe_gate=True, disable_recompute=disable_recompute, + ) + + assert_close("Aqk", Aqk_r, Aqk_c, 0.005) + assert_close("Akk", Akk_r, Akk_c, 0.008) + assert_close("w", w_r, w_c, 0.008) + assert_close("u", u_r, u_c, 0.008) + assert_close("kg", kg_r, kg_c, 0.005) + if disable_recompute: + assert_close("qg", qg_r, qg_c, 0.005) + + +# ========================================================================= +# Shape / contract sanity checks (run even without a reference) +# ========================================================================= + +@pytest.mark.parametrize("group_size", [1, 2, 4]) +def test_gva_intra_output_shapes(group_size): + """All outputs of chunk_kda_fwd_intra must live in HV-head space.""" + B, T, HQK, D = 1, 256, 2, 128 + HV = HQK * group_size + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + q, k, v, g, beta, scale, cu_seqlens, chunk_indices = _make_gva_inputs( + B=B, T=T, HQK=HQK, HV=HV, D=D, chunk_size=chunk_size, cu_seqlens=cu_seqlens, + ) + w, u, qg, kg, Aqk, Akk = _run_cula_gva( + q, k, v, g, beta, scale, cu_seqlens, chunk_indices, chunk_size, disable_recompute=True, + ) + + assert Aqk.shape == (B, T, HV, chunk_size), Aqk.shape + assert Akk.shape == (B, T, HV, chunk_size), Akk.shape + assert w.shape == (B, T, HV, D), w.shape + assert u.shape == (B, T, HV, D), u.shape + assert kg.shape == (B, T, HV, D), kg.shape + assert qg is not None and qg.shape == (B, T, HV, D), (None if qg is None else qg.shape) + + +# ========================================================================= +# Negative / assertion tests +# ========================================================================= + +def test_gva_intra_rejects_non_multiple_ratio(): + """HV must be a positive integer multiple of HQK.""" + B, T, HQK, HV, D = 1, 128, 3, 5, 128 # 5 % 3 != 0 + chunk_size = 64 + cu_seqlens = prepare_uniform_cu_seqlens(B, T, torch.device(device), torch.int32) + # We intentionally do not use _make_gva_inputs because the assert fires + # before kernel launch on the python side. + dtype = torch.bfloat16 + q = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + k = torch.randn(B, T, HQK, D, dtype=dtype, device=device) + v = torch.randn(B, T, HV, D, dtype=dtype, device=device) + g = torch.randn(B, T, HV, D, dtype=torch.float, device=device) + beta = torch.randn(B, T, HV, dtype=torch.float, device=device).sigmoid() + + with pytest.raises(AssertionError): + cula_chunk_kda_fwd_intra( + q=q, k=k, v=v, gk=g, beta=beta, scale=D ** -0.5, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, + safe_gate=True, disable_recompute=False, + )
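# =========================================================================
# Hedged, CPU-only sketch (illustrative; not exercised by the suite): the head
# replication used by _run_fla_ref mirrors the GVA pairing
# qk_head_idx = v_head_idx // group_size.
# =========================================================================

def _sketch_head_replication(group_size: int = 2) -> None:
    k = torch.arange(2).view(1, 1, 2, 1)           # two qk-heads holding 0 and 1
    k_hv = k.repeat_interleave(group_size, dim=2)  # v-heads read 0, 0, 1, 1
    flat = k_hv.flatten().tolist()
    assert flat == [v // group_size for v in range(2 * group_size)]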