ck_tile grouped gemm: more padding #574
Conversation
```cpp
if (need_k_pad && ctx.transB) {
    return false;
}
```
Assuming we have access to B's column-wise storage, could we avoid this fallback by selecting B's column-wise buffer and calling CK with transB=false, while preserving the same logical GEMM? In other words, is the incorrect-result issue specific to CK's kPadK + transB=true / ColMajor-B path, or would the column-wise-buffer normalization still hit the same underlying issue?
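For context, a minimal NumPy sketch of the equivalence the question is asking about (array names are hypothetical; this only illustrates the linear-algebra/storage identity, and says nothing about whether CK's kPadK path misbehaves either way):

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, K = 4, 6, 8
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((N, K)).astype(np.float32)  # row-major (N, K): the transB=true operand

# Path 1: transB=true -- the GEMM computes A @ B.T directly.
out_transB = A @ B.T

# Path 2: reinterpret the very same bytes column-wise. A row-major (N, K)
# buffer is bit-identical to a column-major (K, N) buffer, so B.T is
# already an F-contiguous view -- no copy, no transpose kernel.
B_colwise = B.T
assert B_colwise.flags["F_CONTIGUOUS"]
out_no_transB = A @ B_colwise  # "transB=false" on the column-wise buffer

assert np.allclose(out_transB, out_no_transB)
```

So the two calls are the same logical GEMM over the same bytes; whether CK's kernel accepts the second form where it rejects the first is exactly the open question above.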
```cpp
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even), "
"N must be multiple of 16 (GetVectorSizeC)."
```
I'm not sure this function (GetVectorSizeC) implies that N must be a multiple of 16. For the M16N16K32 warp-gemm path, the relevant attributes appear to be kN = 16 and kCNLane = 16, so in the row-major non-TransposeC case this returns kCNLane / kN = 1. That seems to describe the per-thread contiguous C vector size rather than an N divisibility requirement. Is there another place where the N % 16 == 0 constraint is enforced or assumed?
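A tiny arithmetic sketch of the reading above (the attribute names kN and kCNLane come from the comment; the formula is my reconstruction of its argument, not CK's actual code):

```python
# Hypothetical reconstruction of the GetVectorSizeC reasoning for the
# M16N16K32 warp-gemm tile, row-major / non-TransposeC case.
kN = 16        # warp-tile N extent
kCNLane = 16   # lanes along N in the C layout

# Per the comment's reading, the per-thread contiguous C vector size is:
vector_size_c = kCNLane // kN  # 16 // 16 == 1

# A per-thread vector size of 1 constrains nothing about N's divisibility;
# an N % 16 == 0 requirement would have to come from somewhere else
# (e.g. the tile shape or epilogue store vectorization).
assert vector_size_c == 1
```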
```python
    reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN"])
```
Ah, perhaps the reason NT is left out has to do with your comment about kPadK + column-major B.
```python
if pad_dim == "K":
    gemm_k = unaligned_k
```
This test seems to cover one unaligned dimension at a time, but not the combined kPadM && kPadK case. Since the dispatch logic can instantiate a runner with padding set for both dimensions, should we add a case where both M and K are unaligned?
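A hedged sketch of how the combined case might look (the aligned/unaligned sizes and the helper are illustrative, modeled on the test's pad_dim pattern rather than taken from the diff):

```python
# Hypothetical helper mirroring the test's dimension selection.
ALIGNED = (256, 256)    # (gemm_m, gemm_k) when fully aligned (illustrative)
UNALIGNED = (255, 255)  # odd sizes that force kPadM / kPadK (illustrative)

def resolve_dims(pad_dim: str) -> tuple[int, int]:
    """Return (gemm_m, gemm_k) with each dimension named in pad_dim unaligned."""
    gemm_m, gemm_k = ALIGNED
    if "M" in pad_dim:
        gemm_m = UNALIGNED[0]
    if "K" in pad_dim:
        gemm_k = UNALIGNED[1]
    return gemm_m, gemm_k

# Existing single-dimension cases:
assert resolve_dims("M") == (255, 256)
assert resolve_dims("K") == (256, 255)
# Proposed combined case exercising kPadM && kPadK together:
assert resolve_dims("MK") == (255, 255)
```

In the test itself this would just mean adding "MK" to the pad_dim parametrization and making the M and K checks independent rather than exclusive.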