[Draft] [KDA] sm100 GVA enhance #65
Conversation
Code Review
This pull request implements Grouped V-head Attention (GVA) support across the KDA kernels for both SM90 and SM100 architectures. Key changes include decoupling head counts for Q/K and V/G tensors, updating TMA descriptors and tile scheduling logic to handle these grouped configurations, and adding comprehensive validation checks. The Python API and test suite have been updated to support and verify GVA functionality. Feedback from the review identifies a documentation mismatch regarding tensor layouts in the SM100 mainloop and suggests correcting terminology in Python error messages to distinguish between head count and head dimension.
int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16);

Suggested change:
- // GMEM output address: layout [total_len, d, h], stride [d*h, 1, d]
+ // GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d]
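As a sanity check on the corrected comment, the stride arithmetic can be written out directly. A minimal Python sketch; the function and argument names are illustrative, not from the kernel:

```python
# Flat-offset arithmetic implied by the corrected layout comment:
# layout [total_len, d, h_v], stride [d*h_v, 1, d].
def gmem_out_offset(token: int, col: int, v_head: int, d: int, h_v: int) -> int:
    return token * (d * h_v) + col * 1 + v_head * d

# The d columns of one head are contiguous (stride 1), heads sit d apart,
# and consecutive tokens are d * h_v apart.
assert gmem_out_offset(0, 0, 1, d=64, h_v=8) == 64
assert gmem_out_offset(1, 0, 0, d=64, h_v=8) == 512
```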
| f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}" | ||
| ) | ||
| assert HV > 0 and HQK > 0 and HV % HQK == 0, ( | ||
| f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})" |
The error message incorrectly uses the term 'head-dim' when referring to HV and HQK, which represent the number of heads (head count). The head dimension is represented by K.
| f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})" | |
| f"v head count (HV={HV}) must be a positive multiple of k head count (HQK={HQK})" |
Follow the GVA pattern used in the SM90 KDA (and in gated_delta_rule GVA) so that the SM100 KDA forward pass can handle num_v_heads > num_qk_heads.

C++ changes:
- tile_scheduler: Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx (= v_head_idx / heads_per_group). When HV == HQK this degenerates to the previous behaviour.
- kda_config: KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group; Akk and w/u/kg/qg layouts now live in v-head space.
- intra kernel/mainloop: Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v). Load warp slices Q/K with qk_head_idx and g with v_head_idx; Aqk row stride and beta stride now use params.h_v.
- recomp_w_u kernel/mainloop: K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v). Load warp slices K/Q with qk_head_idx and V/g/Akk with v_head_idx; w/u/kg/qg write stride and beta stride now use params.h_v.

API / Python:
- kda_sm100.cu: derive h_qk from Q/K and h_v from V/g; validate HV % HQK == 0 and beta/qg_out shapes.
- cula/kda/chunk_intra.py: infer HQK from k.shape[2] and HV from v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions.

Backward compatible: when HV == HQK, heads_per_group == 1 and qk_head_idx == v_head_idx, and all shapes/strides reduce to the pre-GVA layout.
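A minimal Python model of the head mapping this commit describes; the real decode_tile_coord is CUDA C++ and its enumeration order may differ, so this only sketches the index arithmetic:

```python
def decode_tile_coord(tile_idx: int, num_v_heads: int, heads_per_group: int):
    """Sketch of the v-head-space tile decode described above.

    Tiles are enumerated in v-head space; each v head maps back to the q/k
    head of its group via integer division.
    """
    chunk_idx, v_head_idx = divmod(tile_idx, num_v_heads)
    qk_head_idx = v_head_idx // heads_per_group
    return chunk_idx, v_head_idx, qk_head_idx

# HV == HQK degenerates to the old behaviour: heads_per_group == 1,
# so qk_head_idx == v_head_idx for every tile.
assert decode_tile_coord(5, 4, 1) == (1, 1, 1)
# With HV = 4, HQK = 2 (heads_per_group = 2), v heads 2 and 3 share q/k head 1.
assert decode_tile_coord(3, 4, 2) == (0, 3, 1)
```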
Force-pushed from e0e3494 to 58535e2.
Summary
Extend SM100 KDA forward to support num_v_heads > num_qk_heads (GVA), following the pattern already established by the SM90 KDA and by gated_delta_rule GVA.
Branch: feat/kda-sm100-gva, single commit e0e3494.

What's changed
Scheduler & config
- tile_scheduler.hpp – Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx = v_head_idx / heads_per_group. When HV == HQK, heads_per_group == 1 and behaviour is unchanged.
- kda_config.hpp – KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk, h_v, and cache heads_per_group. Akk and w/u/kg/qg layouts now live in v-head space.
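A rough Python analogue of the params split described above; the actual structs are C++, only the field names come from the PR, and the rest is assumed:

```python
from dataclasses import dataclass

@dataclass
class FwdParamsSketch:
    # Mirrors the h -> (h_qk, h_v) split described above.
    h_qk: int             # number of q/k heads
    h_v: int              # number of v/g heads
    heads_per_group: int  # cached h_v // h_qk

def make_params(h_qk: int, h_v: int) -> FwdParamsSketch:
    assert h_v % h_qk == 0, "HV must be a positive multiple of HQK"
    return FwdParamsSketch(h_qk, h_v, h_v // h_qk)

params = make_params(h_qk=4, h_v=8)
assert params.heads_per_group == 2
```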
Intra kernel / mainloop

- Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v).
- Load warp slices Q/K with qk_head_idx and g with the v-head index.
- Aqk row stride and beta stride now use params.h_v.
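The load-side change amounts to picking a different head index per tensor; a sketch assuming the (total, d, h) layout quoted above, with all names illustrative:

```python
import torch

def slice_heads_for_tile(q, k, g, v_head_idx, heads_per_group):
    # Q/K are sliced in q/k-head space, g in v-head space, mirroring the
    # load-warp change described above.
    qk_head_idx = v_head_idx // heads_per_group
    return q[:, :, qk_head_idx], k[:, :, qk_head_idx], g[:, :, v_head_idx]

total, d, h_qk, h_v = 256, 64, 4, 8
q = torch.randn(total, d, h_qk)
k = torch.randn(total, d, h_qk)
g = torch.randn(total, d, h_v)
q_h, k_h, g_h = slice_heads_for_tile(q, k, g, v_head_idx=5,
                                     heads_per_group=h_v // h_qk)
assert q_h.shape == k_h.shape == g_h.shape == (total, d)
```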
Recomp W/U kernel / mainloop

- K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v).
- Load warp slices K/Q with qk_head_idx and V/g/Akk with the v-head index.
- w/u/kg/qg write stride and beta stride now use params.h_v.
API / Python

- csrc/api/kda_sm100.cu – derive h_qk from Q/K and h_v from V/g; assert HV % HQK == 0 plus beta/qg_out shape checks.
- cula/kda/chunk_intra.py – infer HQK = k.shape[2], HV = v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions.
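A sketch of what "allocate in v-head space" means on the Python side: shape_Akk (total, BT, h_v) is quoted from the description above, while the w/kg/qg shapes and all sizes are assumptions:

```python
import torch

# Hypothetical sizes; the point is only that the head-count dimension of the
# intermediates is HV (v-head space), not HQK.
total, d, BT, HQK, HV = 4096, 128, 64, 4, 8

Akk = torch.empty(total, BT, HV)  # shape_Akk from the PR description
w   = torch.empty(total, d, HV)   # assumed orientation
kg  = torch.empty(total, d, HV)
qg  = torch.empty(total, d, HV)

assert Akk.shape[-1] == HV and HV % HQK == 0
```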
Backward compatibility

When HV == HQK:

- heads_per_group == 1
- qk_head_idx == v_head_idx
- shape_QK == shape_VG

No existing HV == HQK workloads should observe any behavioural change.
Known follow-ups (not part of this PR)
The end-to-end SM100 path (chunk_kda_fwd in cula/kda/chunk_fwd.py) feeds the intra/recomp outputs into chunk_gated_delta_rule_fwd_h and chunk_gla_fwd_o, which currently assume q/v/g/A/o share the same head count. This PR intentionally leaves those two CuTe kernels untouched (SM90 does not go through them, so mirroring SM90 leaves them out of scope). A follow-up PR is needed to teach those two kernels GVA before the full SM100 pipeline runs with HV > HQK.

Testing
- Existing coverage exercises HV == HQK.
- End-to-end validation with HV > HQK will follow once the downstream CuTe kernels are GVA-ready.

Draft because the downstream GVA-enablement work and end-to-end validation are still pending.
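Once those HV > HQK tests land, one natural way to build a reference in plain PyTorch is to expand q/k into v-head space with repeat_interleave, which realises the qk_head_idx = v_head_idx // heads_per_group mapping. A sketch, not the repo's actual test:

```python
import torch

def expand_to_v_heads(x: torch.Tensor, heads_per_group: int) -> torch.Tensor:
    # (B, T, HQK, d) -> (B, T, HQK * heads_per_group, d): v head i reads the
    # q/k head i // heads_per_group, matching the scheduler's mapping.
    return x.repeat_interleave(heads_per_group, dim=2)

B, T, HQK, d, heads_per_group = 2, 64, 4, 128, 2
k = torch.randn(B, T, HQK, d)
k_v = expand_to_v_heads(k, heads_per_group)
assert k_v.shape == (B, T, HQK * heads_per_group, d)
assert torch.equal(k_v[:, :, 5], k[:, :, 5 // heads_per_group])
```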
#55