88 changes: 79 additions & 9 deletions csrc/api/kda_sm100.cu
@@ -37,7 +37,33 @@ ChunkKDAFwdIntra(
KDA_fwd_intra_params params;
params.total_q_len = q.size(0) * q.size(1);
params.b = cu_seqlens.size(0) - 1;
params.h = q.size(2);
// GVA: Q/K are in h_qk head space (from q.size(2)); g/beta/Aqk/Akk are in h_v head
// space (from g.size(2)). When HV == HQK, heads_per_group == 1 and behaviour matches
// the pre-GVA path.
params.h_qk = q.size(2);
params.h_v = g.size(2);
TORCH_CHECK(
k.size(2) == params.h_qk,
"ChunkKDAFwdIntra: k.size(2) (",
k.size(2),
") must match q.size(2) (",
params.h_qk,
") under GVA (Q/K share h_qk).");
TORCH_CHECK(
beta.size(-1) == params.h_v,
"ChunkKDAFwdIntra: beta.size(-1) (",
beta.size(-1),
") must equal h_v (",
params.h_v,
").");
TORCH_CHECK(
params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0,
"ChunkKDAFwdIntra: h_v (",
params.h_v,
") must be a positive multiple of h_qk (",
params.h_qk,
").");
params.heads_per_group = params.h_v / params.h_qk;
params.d = q.size(3);
params.chunk_size = chunk_size;
params.scale = scale;
@@ -56,13 +82,15 @@
params.chunk_indices_ptr = chunk_indices.data_ptr();
params.Aqk_out_ptr = Aqk_out.data_ptr();
params.Akk_out_ptr = Akk_out.data_ptr();
params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h);
params.stride_Akk = cute::make_stride(params.chunk_size * params.h, cute::_1{}, params.chunk_size);
// Akk is laid out per v-head: (total_len, chunk_size, h_v).
params.shape_Akk = cute::make_shape(params.total_q_len, params.chunk_size, params.h_v);
params.stride_Akk = cute::make_stride(params.chunk_size * params.h_v, cute::_1{}, params.chunk_size);
int tile_num = chunk_indices.size(0);
auto device_prop = at::cuda::getCurrentDeviceProperties();
params.num_sm = device_prop->multiProcessorCount;
params.tile_scheduler_params =
StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, (int*)tile_counter.data_ptr()};
// Tiles are enumerated in v-head space.
params.tile_scheduler_params = StaticPersistentTileScheduler::Params{
tile_num, params.h_v, params.heads_per_group, params.num_sm, (int*)tile_counter.data_ptr()};

kda::sm100::run_kda_fwd_intra_sm100(params, at::cuda::getCurrentCUDAStream());
}
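
Review note on the checks above: the host only validates that h_v is a positive multiple of h_qk and precomputes heads_per_group; the per-head mapping itself happens in the kernel. Below is a minimal sketch of the arithmetic those checks imply, assuming contiguous grouping (each block of heads_per_group v-heads shares one qk-head) — the kernel's actual ordering may differ.

```cpp
// Sketch only: head-group arithmetic implied by the TORCH_CHECKs above.
// Contiguous grouping (v_head / heads_per_group -> qk-head) is an assumption.
#include <cassert>

struct GvaHeadMap {
  int h_qk;             // number of Q/K heads
  int h_v;              // number of V/g/beta heads
  int heads_per_group;  // = h_v / h_qk, as precomputed on the host

  GvaHeadMap(int hqk, int hv) : h_qk(hqk), h_v(hv), heads_per_group(hv / hqk) {
    assert(hqk > 0 && hv > 0 && hv % hqk == 0);  // same invariant as the checks above
  }

  // qk-head whose Q/K tiles a given v-head consumes.
  int qk_head(int v_head) const { return v_head / heads_per_group; }
};

int main() {
  GvaHeadMap map(/*h_qk=*/4, /*h_v=*/8);  // heads_per_group == 2
  assert(map.qk_head(0) == 0 && map.qk_head(1) == 0);  // v-heads 0,1 share qk-head 0
  assert(map.qk_head(7) == 3);                         // last v-head -> last qk-head
  return 0;
}
```

When h_v == h_qk this degenerates to the identity map, matching the pre-GVA path noted in the comment.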
@@ -85,7 +113,31 @@ ChunkKDAFwdRecompWU(
KDA_fwd_recomp_w_u_params params;
params.total_len = k.size(0) * k.size(1);
params.b = cu_seqlens.size(0) - 1;
params.h = k.size(2);
// GVA: K (and optional Q) live in h_qk space; V/G/beta/A/w/u/kg/qg live in h_v space.
params.h_qk = k.size(2);
params.h_v = v.size(2);
TORCH_CHECK(
g.size(2) == params.h_v,
"ChunkKDAFwdRecompWU: g.size(2) (",
g.size(2),
") must equal v.size(2) (",
params.h_v,
").");
TORCH_CHECK(
beta.size(-1) == params.h_v,
"ChunkKDAFwdRecompWU: beta.size(-1) (",
beta.size(-1),
") must equal h_v (",
params.h_v,
").");
TORCH_CHECK(
params.h_qk > 0 && params.h_v > 0 && params.h_v % params.h_qk == 0,
"ChunkKDAFwdRecompWU: h_v (",
params.h_v,
") must be a positive multiple of h_qk (",
params.h_qk,
").");
params.heads_per_group = params.h_v / params.h_qk;
params.d = k.size(3);
params.chunk_size = chunk_size;
TORCH_CHECK(
@@ -108,14 +160,32 @@
TORCH_CHECK(
has_q == has_qg_out, "ChunkKDAFwdRecompWU: q and qg_out must either both be provided or both be omitted.");
params.store_qg = has_q && has_qg_out;
if (params.store_qg) {
TORCH_CHECK(
q->size(2) == params.h_qk,
"ChunkKDAFwdRecompWU: q.size(2) (",
q->size(2),
") must equal h_qk (",
params.h_qk,
").");
TORCH_CHECK(
qg_out->size(2) == params.h_v,
"ChunkKDAFwdRecompWU: qg_out.size(2) (",
qg_out->size(2),
") must equal h_v (",
params.h_v,
").");
}
params.q_ptr = params.store_qg ? q->data_ptr() : nullptr;
params.qg_out_ptr = params.store_qg ? qg_out->data_ptr() : nullptr;
params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h);
params.stride_wukg = cute::make_stride(params.d * params.h, cute::_1{}, params.d);
// w/u/kg/qg are per v-head: (total_len, d, h_v).
params.shape_wukg = cute::make_shape(params.total_len, params.d, params.h_v);
params.stride_wukg = cute::make_stride(params.d * params.h_v, cute::_1{}, params.d);
int tile_num = chunk_indices.size(0);
auto device_prop = at::cuda::getCurrentDeviceProperties();
params.num_sm = device_prop->multiProcessorCount;
params.tile_scheduler_params = StaticPersistentTileScheduler::Params{tile_num, params.h, params.num_sm, nullptr};
params.tile_scheduler_params = StaticPersistentTileScheduler::Params{
tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr};

kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream());
}
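
Both new output layouts in this file follow the same pattern — shape (total_len, X, h_v) with stride (X * h_v, 1, X), i.e. token-major, then head, with the inner X dimension contiguous (X = chunk_size for Akk, X = d for w/u/kg/qg). A small self-contained sketch of the resulting addressing, using plain integer arithmetic as a stand-in for cute::make_shape/make_stride:

```cpp
// Offset arithmetic for shape (total_len, X, h_v) with stride (X * h_v, 1, X):
// element (t, x, h) lives at t * X * h_v + x + h * X.
#include <cassert>
#include <cstdint>

inline int64_t offset(int t, int x, int h, int X, int h_v) {
  return static_cast<int64_t>(t) * X * h_v + x + static_cast<int64_t>(h) * X;
}

int main() {
  const int X = 128, h_v = 8;  // e.g. X = d for w/u/kg/qg, X = chunk_size for Akk
  // Inner dimension is contiguous (stride 1).
  assert(offset(0, 1, 0, X, h_v) - offset(0, 0, 0, X, h_v) == 1);
  // Heads are X apart; tokens are X * h_v apart.
  assert(offset(0, 0, 1, X, h_v) - offset(0, 0, 0, X, h_v) == X);
  assert(offset(1, 0, 0, X, h_v) - offset(0, 0, 0, X, h_v) == X * h_v);
  return 0;
}
```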
50 changes: 30 additions & 20 deletions csrc/kda/sm100/kda_config.hpp
@@ -17,25 +17,31 @@
#include "kda/sm100/tile_scheduler.hpp"

struct KDA_fwd_intra_params {
using GmemShapeAkk = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_kv, seqlen_kv, h)
// Akk shape is (total_seqlen, chunk_size, num_v_heads). Under GVA (num_v_heads > num_qk_heads),
// Aqk and Akk are produced per v-head because g/beta/Akk scaling all live in v-head space.
using GmemShapeAkk = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_kv, chunk_size, h_v)
using GmemStrideAkk = cute::Stride<int32_t, cute::_1, int32_t>;

int total_q_len;
int b;
int h;
// GVA: Q/K are sized by num_qk_heads; V, g, beta are sized by num_v_heads; Aqk/Akk are per v-head.
// When num_v_heads == num_qk_heads, heads_per_group == 1 and behaviour matches the pre-GVA path.
int h_qk;
int h_v;
int heads_per_group; // = h_v / h_qk, precomputed on host
int d;
int chunk_size;
float scale;
bool use_tf32_inverse;
bool unified_gref;
bool is_beta_bf16;

void* __restrict__ q_ptr; //[b, t, h, d]
void* __restrict__ k_ptr; //[b, t, h, d]
void* __restrict__ g_ptr; //[b, t, h, d]
void* __restrict__ beta_ptr; //[b, t, h]
void* __restrict__ Aqk_out_ptr; //[b, t, h, BT]
void* __restrict__ Akk_out_ptr; //[b, t, h, BT]
void* __restrict__ q_ptr; //[b, t, h_qk, d]
void* __restrict__ k_ptr; //[b, t, h_qk, d]
void* __restrict__ g_ptr; //[b, t, h_v, d]
void* __restrict__ beta_ptr; //[b, t, h_v]
void* __restrict__ Aqk_out_ptr; //[b, t, h_v, BT]
void* __restrict__ Akk_out_ptr; //[b, t, h_v, BT]
void* __restrict__ cu_seqlens_ptr; //[b + 1]
void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2]

@@ -48,28 +54,32 @@ struct KDA_fwd_intra_params {
};

struct KDA_fwd_recomp_w_u_params {
using GmemShapeWUKg = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_kv, seqlen_kv, h)
// w/u/kg/qg all have shape (total_seqlen, d, num_v_heads) under GVA.
using GmemShapeWUKg = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_kv, d, h_v)
using GmemStrideWUKg = cute::Stride<int32_t, cute::_1, int32_t>;

int total_len;
int b;
int h;
// GVA: K and (optional) Q are sized by num_qk_heads; V/G/beta/Akk/w/u/kg/qg are per v-head.
int h_qk;
int h_v;
int heads_per_group; // = h_v / h_qk, precomputed on host
int d;
int chunk_size;
bool is_beta_bf16;

void* __restrict__ k_ptr; //[b, t, h, d]
void* __restrict__ v_ptr; //[b, t, h, d]
void* __restrict__ q_ptr; //[b, t, h, d] (optional, for StoreQG)
void* __restrict__ beta_ptr; //[b, t, h]
void* __restrict__ A_ptr; //[b. t, h, BT]
void* __restrict__ g_ptr; //[b, t, h, d]
void* __restrict__ k_ptr; //[b, t, h_qk, d]
void* __restrict__ v_ptr; //[b, t, h_v, d]
void* __restrict__ q_ptr; //[b, t, h_qk, d] (optional, for StoreQG)
void* __restrict__ beta_ptr; //[b, t, h_v]
void* __restrict__ A_ptr; //[b, t, h_v, BT]
void* __restrict__ g_ptr; //[b, t, h_v, d]
void* __restrict__ cu_seqlens_ptr; //[b + 1]
void* __restrict__ chunk_indices_ptr; //[(b * t) / chunk_size, 2]
void* __restrict__ w_out_ptr; //[b, t, h, d]
void* __restrict__ u_out_ptr; //[b, t, h, d]
void* __restrict__ kg_out_ptr; //[b, t, h, d]
void* __restrict__ qg_out_ptr; //[b, t, h, d] (optional, for StoreQG)
void* __restrict__ w_out_ptr; //[b, t, h_v, d]
void* __restrict__ u_out_ptr; //[b, t, h_v, d]
void* __restrict__ kg_out_ptr; //[b, t, h_v, d]
void* __restrict__ qg_out_ptr; //[b, t, h_v, d] (optional, for StoreQG)

bool store_qg;

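
The Params change in these structs (h_v plus heads_per_group instead of a single h) suggests the persistent scheduler now enumerates (tile, v-head) work units, with each unit recovering its shared qk-head. A hedged sketch of one plausible decode — StaticPersistentTileScheduler's real iteration order is not shown in this diff, so the tile-fastest ordering here is an assumption:

```cpp
// Hypothetical decode of a linear work id into (tile, v-head, qk-head).
// Tile-fastest ordering and the division-based head mapping are assumptions.
#include <cassert>

struct WorkUnit {
  int tile_idx;  // chunk tile within the sequence
  int v_head;    // head index in v-head space (what the scheduler enumerates)
  int qk_head;   // shared Q/K head for this group
};

inline WorkUnit decode(int linear_id, int tile_num, int heads_per_group) {
  WorkUnit w;
  w.tile_idx = linear_id % tile_num;
  w.v_head = linear_id / tile_num;
  w.qk_head = w.v_head / heads_per_group;
  return w;
}

int main() {
  // tile_num = 10, heads_per_group = 2: ids 0..9 cover v-head 0 (qk-head 0),
  // ids 10..19 cover v-head 1 (still qk-head 0), ids 20..29 cover v-head 2.
  WorkUnit w = decode(23, /*tile_num=*/10, /*heads_per_group=*/2);
  assert(w.tile_idx == 3 && w.v_head == 2 && w.qk_head == 1);
  return 0;
}
```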
29 changes: 20 additions & 9 deletions csrc/kda/sm100/kda_fwd_intra_kernel_sm100.hpp
@@ -53,8 +53,8 @@ struct KdaChunkFwdIntraKernelSm100 {
using SmemLayoutInputFP32 = typename Mainloop::SmemLayoutInputFP32;

// TMA params (for host launcher)
template <typename ShapeQKG, typename TMA_Q, typename TMA_K, typename TMA_G>
using TmaParams = typename Mainloop::template TmaParams<ShapeQKG, TMA_Q, TMA_K, TMA_G>;
template <typename ShapeQK, typename ShapeVG, typename TMA_Q, typename TMA_K, typename TMA_G>
using TmaParams = typename Mainloop::template TmaParams<ShapeQK, ShapeVG, TMA_Q, TMA_K, TMA_G>;

// Pipeline types (for construction in operator())
using PipelineQKG = typename Mainloop::PipelineQKG;
@@ -321,29 +321,40 @@ __launch_bounds__(512, 1, 1) kda_fwd_intra_sm100_kernel_entry(
template <typename Kernel>
inline void
run_kda_fwd_intra_sm100_impl_dispatch(KDA_fwd_intra_params& params, cudaStream_t stream) {
auto shape_QKG = make_shape(params.total_q_len, params.d, params.h);
auto stride_QKG = make_stride(params.h * params.d, _1{}, params.d);
// GVA: Q/K are sized by `h_qk`; G is sized by `h_v`. When HV == HQK
// (heads_per_group == 1), shape_QK and shape_VG coincide with the
// pre-GVA shape_QKG and behaviour is unchanged.
auto shape_QK = make_shape(params.total_q_len, params.d, params.h_qk);
auto stride_QK = make_stride(params.h_qk * params.d, _1{}, params.d);
auto shape_VG = make_shape(params.total_q_len, params.d, params.h_v);
auto stride_VG = make_stride(params.h_v * params.d, _1{}, params.d);

// --- Build TMA descriptors ---
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QKG, stride_QKG)),
make_tensor(make_gmem_ptr((ku::bf16*)params.q_ptr), make_layout(shape_QK, stride_QK)),
typename Kernel::SmemLayoutInputBF16{});

auto tma_K = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QKG, stride_QKG)),
make_tensor(make_gmem_ptr((ku::bf16*)params.k_ptr), make_layout(shape_QK, stride_QK)),
typename Kernel::SmemLayoutInputBF16{});

auto tma_G = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_QKG, stride_QKG)),
make_tensor(make_gmem_ptr((float*)params.g_ptr), make_layout(shape_VG, stride_VG)),
typename Kernel::SmemLayoutInputFP32{});

// --- Pack TMA params ---
typename Kernel::template TmaParams<decltype(shape_QKG), decltype(tma_Q), decltype(tma_K), decltype(tma_G)>
typename Kernel::template TmaParams<
decltype(shape_QK),
decltype(shape_VG),
decltype(tma_Q),
decltype(tma_K),
decltype(tma_G)>
tma_params = {
shape_QKG,
shape_QK,
shape_VG,
tma_Q,
tma_K,
tma_G,
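
One consequence of splitting shape_QKG into shape_QK and shape_VG is easy to verify: when h_qk == h_v, the two (shape, stride) pairs are identical, so the TMA descriptors are built over the same layout as before — which is exactly what the "behaviour is unchanged" comment claims. A standalone check of that claim, using std::tuple as a stand-in for cute::make_shape/make_stride:

```cpp
// Stand-in check: shape_QK/stride_QK and shape_VG/stride_VG coincide when
// h_qk == h_v (heads_per_group == 1), so the pre-GVA layout is preserved.
#include <cassert>
#include <tuple>

int main() {
  const int total_q_len = 4096, d = 128;  // illustrative sizes
  const int h_qk = 8, h_v = 8;            // heads_per_group == 1
  auto shape_QK = std::make_tuple(total_q_len, d, h_qk);
  auto stride_QK = std::make_tuple(h_qk * d, 1, d);
  auto shape_VG = std::make_tuple(total_q_len, d, h_v);
  auto stride_VG = std::make_tuple(h_v * d, 1, d);
  assert(shape_QK == shape_VG && stride_QK == stride_VG);
  return 0;
}
```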