-
Notifications
You must be signed in to change notification settings - Fork 28
ck_tile grouped gemm: more padding #574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3078,6 +3078,128 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): | |
| os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, | ||
| reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm", | ||
| ) | ||
| @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str) | ||
| @pytest.mark.parametrize("layout", ["TN", "NN"]) | ||
| @pytest.mark.parametrize("accumulate", [False, True]) | ||
| @pytest.mark.parametrize( | ||
| "pad_dim", | ||
| ["K", "M", "N"], | ||
| ids=lambda d: f"pad{d}", | ||
| ) | ||
| def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim): | ||
| """Test CK grouped GEMM with M, N, or K not aligned to CK tile size. | ||
|
|
||
| CK constraints for bf16/fp16: | ||
| - Contiguous dim of A/B must be dword-aligned (even for 2-byte types). | ||
| RowMajor: contiguous dim is cols (K for A, N for B). | ||
| ColMajor: contiguous dim is rows (M for A, K for B). | ||
| - N: must be multiple of 16 (GetVectorSizeC, no dword fallback), tile 128/256 | ||
| - K tile: 64, M tile: 256 | ||
| """ | ||
| torch.manual_seed(0) | ||
| z = 8 | ||
|
|
||
| # Unaligned values per dimension (all satisfy CK vector-load constraints). | ||
| # K: even but not multiple of tile (64). Same for all groups. | ||
| # M: not multiples of tile (256), varies per group. | ||
| # N: multiple of 16 but not multiple of tile (128). | ||
| unaligned_k = 2026 | ||
| unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180] | ||
| unaligned_n = 2032 | ||
|
|
||
| # Aligned defaults. | ||
| k_aligned = 2048 | ||
| m_aligned = 256 | ||
| n_aligned = 2048 | ||
|
|
||
| os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" | ||
|
|
||
| if layout == "TN": | ||
| # TN GEMM: M=m_splits[i], N=A.rows, K=A.cols | ||
| if pad_dim == "K": | ||
| k_val = unaligned_k | ||
| m_vals = [m_aligned] * z | ||
| n_val = n_aligned | ||
| elif pad_dim == "M": | ||
| k_val = k_aligned | ||
| m_vals = unaligned_m | ||
| n_val = n_aligned | ||
| else: # N | ||
| k_val = k_aligned | ||
| m_vals = [m_aligned] * z | ||
| n_val = unaligned_n | ||
|
|
||
| A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)] | ||
| B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals] | ||
| total_m = sum(m_vals) | ||
| out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")] | ||
| out_ref = [o.clone() for o in torch.split(out[0], m_vals)] | ||
| m_splits = m_vals | ||
| grad = False | ||
| single_output = True | ||
| else: # NN | ||
| # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows | ||
| if pad_dim == "K": | ||
| gemm_k = unaligned_k | ||
|
Comment on lines
+3146
to
+3147
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test seems to cover one unaligned dimension at a time, but not the combined kPadM && kPadK case. Since the dispatch logic can instantiate a runner with padding set for both dimensions, should we add a case where both M and K are unaligned? |
||
| m_vals = [m_aligned] * z | ||
| n_out = n_aligned | ||
| elif pad_dim == "M": | ||
| gemm_k = k_aligned | ||
| m_vals = unaligned_m | ||
| n_out = n_aligned | ||
| else: # N | ||
| gemm_k = k_aligned | ||
| m_vals = [m_aligned] * z | ||
| n_out = unaligned_n | ||
|
|
||
| A = [torch.randn(gemm_k, n_out, dtype=dtype, device="cuda") for _ in range(z)] | ||
| B = [torch.randn(m, gemm_k, dtype=dtype, device="cuda") for m in m_vals] | ||
| total_m = sum(m_vals) | ||
| out = [torch.randn(total_m, n_out, dtype=dtype, device="cuda")] | ||
| out_ref = [o.clone() for o in torch.split(out[0], m_vals)] | ||
| m_splits = m_vals | ||
| grad = True | ||
| single_output = True | ||
|
|
||
| # Reference: individual GEMMs | ||
| for i in range(z): | ||
| general_gemm( | ||
| A[i], | ||
| B[i], | ||
| dtype, | ||
| grad=grad, | ||
| accumulate=accumulate, | ||
| layout=layout, | ||
| out=out_ref[i], | ||
| ) | ||
| if single_output: | ||
| out_ref = [torch.cat(out_ref)] | ||
|
|
||
| general_grouped_gemm( | ||
| A, | ||
| B, | ||
| out, | ||
| [None] * z, | ||
| dtype, | ||
| m_splits=m_splits, | ||
| grad=grad, | ||
| accumulate=accumulate, | ||
| layout=layout, | ||
| single_output=single_output, | ||
| ) | ||
|
|
||
| for o, o_ref in zip(out, out_ref): | ||
| if IS_HIP_EXTENSION and accumulate and dtype == torch.bfloat16 and get_device_compute_capability() == (9, 4): | ||
| torch.testing.assert_close(o, o_ref, rtol=4e-2, atol=4e-2) | ||
| else: | ||
| torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) | ||
|
|
||
| os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) | ||
|
|
||
| @pytest.mark.parametrize("N", [32]) | ||
| @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) | ||
| @pytest.mark.parametrize( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -140,7 +140,17 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs, | |
|
|
||
| if (!Kernel::IsSupportedArgument(kargs)) { | ||
| NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " | ||
| "Falling back."); | ||
| "transA=", ctx.transA, " transB=", ctx.transB, | ||
| " accumulate=", ctx.accumulate, " groups=", ctx.group_num, | ||
| ". Falling back. " | ||
| "CK_Tile constraints for bf16/fp16: " | ||
| "contiguous dim of A and B must be dword-aligned (even), " | ||
| "N must be multiple of 16 (GetVectorSizeC)."); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure this function (GetVectorSizeC) implies that N must be a multiple of 16. For the M16N16K32 warp-gemm path, the relevant attributes appear to be kN = 16 and kCNLane = 16, so in the row-major non-TransposeC case this returns kCNLane / kN = 1. That seems to describe the per-thread contiguous C vector size rather than an N divisibility requirement. Is there another place where the N % 16 == 0 constraint is enforced or assumed? |
||
| for (size_t i = 0; i < descs.size(); ++i) { | ||
| NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K, | ||
| " stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B, | ||
| " stride_E=", descs[i].stride_E); | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,8 +41,11 @@ struct TileCfg_256x128x64 : TileCfg_256x256x64 { | |
| static constexpr ck_tile::index_t N_Tile = 128; | ||
| }; | ||
|
|
||
| struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { | ||
| static constexpr bool kPadN = true; | ||
| template <typename Base, bool PadM_, bool PadN_, bool PadK_> | ||
| struct WithPadding : Base { | ||
| static constexpr bool kPadM = PadM_; | ||
| static constexpr bool kPadN = PadN_; | ||
| static constexpr bool kPadK = PadK_; | ||
| }; | ||
|
|
||
| template <typename AType, | ||
|
|
@@ -196,15 +199,15 @@ class GroupedGemmRunner : public RunnerInterface { | |
| } | ||
| }; | ||
|
|
||
| #define MAKE_RUNNER(TileCfg_) \ | ||
| #define MAKE_RUNNER(BaseCfg_, kPadM_, kPadN_, kPadK_) \ | ||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ | ||
| using Runner = GroupedGemmRunner<AType, \ | ||
| BType, \ | ||
| CType, \ | ||
| ALayout, \ | ||
| BLayout, \ | ||
| CLayout, \ | ||
| TileCfg_, \ | ||
| WithPadding<BaseCfg_, kPadM_, kPadN_, kPadK_>, \ | ||
| accum_option>; \ | ||
| runner = std::make_unique<Runner>(); \ | ||
| }) | ||
|
|
@@ -216,6 +219,37 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, | |
| const ck_tile::stream_config s{ctx.stream}; | ||
| std::unique_ptr<RunnerInterface> runner = nullptr; | ||
|
|
||
| // Check M and K alignment across all groups. | ||
| // All tile configs share the same M_Tile (256) and K_Tile (64). | ||
| constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile; | ||
| constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile; | ||
|
|
||
| bool need_m_pad = false; | ||
| bool need_k_pad = false; | ||
|
|
||
| for (int i = 0; i < ctx.group_num; ++i) { | ||
| const transformer_engine::Tensor* A_te = | ||
| transformer_engine::convertNVTETensorCheck(ctx.A[i]); | ||
| int64_t Ad0 = 0, Ad1 = 0; | ||
| if (get_flat_2d_dims(*A_te, Ad0, Ad1)) { | ||
| const int64_t M = ctx.transA ? Ad1 : Ad0; | ||
| const int64_t K = ctx.transA ? Ad0 : Ad1; | ||
|
|
||
| if (M % M_Tile != 0) | ||
| need_m_pad = true; | ||
| if (K % K_Tile != 0) | ||
| need_k_pad = true; | ||
| if (need_m_pad && need_k_pad) | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| // CK tile kernel produces incorrect results with kPadK + ColMajor B. | ||
| // Fall back to cuBLAS for this combination. | ||
| if (need_k_pad && ctx.transB) { | ||
| return false; | ||
| } | ||
|
Comment on lines
+249
to
+251
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Assuming we have access to B column-wise storage, could we avoid this fallback by selecting B’s column-wise buffer and calling CK with transB=false, while preserving the same logical GEMM? In other words, is the incorrect-result issue specific to CK’s kPadK + transB=true / ColMajor-B path, or would the columnwise-buffer normalization still hit the same underlying issue? |
||
|
|
||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { | ||
| using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>; | ||
|
|
||
|
|
@@ -230,13 +264,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, | |
| TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { | ||
| using CType = typename TETypeToCKType<d_te_type>::type; | ||
|
|
||
| if (ctx.N % 256 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x256x64); | ||
| } else if (ctx.N % 128 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x128x64); | ||
| } else { | ||
| MAKE_RUNNER(TileCfg_256x128x64_padding); | ||
| } | ||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { | ||
| TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { | ||
| if (ctx.N % 256 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); | ||
| } else if (ctx.N % 128 == 0) { | ||
| MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); | ||
| } else { | ||
| MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); | ||
| } | ||
| }); | ||
| }); | ||
| }); | ||
| }); | ||
| }); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh perhaps the reason NT is left out has to do with your comment about kPadK + column major B