From 95f984ce6937fcd5f0c8aaea62c025e4af2b9f81 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Mon, 4 May 2026 19:14:20 -0500
Subject: [PATCH] ck_tile grouped gemm: more padding

---
 tests/pytorch/test_numerics.py               | 122 ++++++++++++++++++
 .../ck_grouped_gemm/ck_grouped_gemm_common.h |  12 +-
 .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp |  60 +++++++--
 3 files changed, 182 insertions(+), 12 deletions(-)

diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 4a768377e..f548b36d6 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -3078,6 +3078,128 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
         os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
 
 
+@pytest.mark.skipif(
+    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
+    reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
+)
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
+@pytest.mark.parametrize("layout", ["TN", "NN"])
+@pytest.mark.parametrize("accumulate", [False, True])
+@pytest.mark.parametrize(
+    "pad_dim",
+    ["K", "M", "N"],
+    ids=lambda d: f"pad{d}",
+)
+def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim):
+    """Test CK grouped GEMM with M, N, or K not aligned to the CK tile size.
+
+    CK constraints for bf16/fp16:
+    - Contiguous dim of A/B must be dword-aligned (an even element count for 2-byte types).
+      RowMajor: the contiguous dim is the columns (K for A, N for B).
+      ColMajor: the contiguous dim is the rows (M for A, K for B).
+    - N must be a multiple of 16 (GetVectorSizeC, no dword fallback); N tile is 128 or 256.
+    - K tile: 64, M tile: 256.
+    """
+    torch.manual_seed(0)
+    z = 8
+
+    # Unaligned values per dimension (all satisfy the CK vector-load constraints).
+    # K: even, but not a multiple of the tile size (64). Same for all groups.
+    # M: not multiples of the tile size (256); varies per group.
+    # N: a multiple of 16, but not a multiple of the tile size (128).
+    unaligned_k = 2026
+    unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
+    unaligned_n = 2032
+
+    # Aligned defaults.
+    k_aligned = 2048
+    m_aligned = 256
+    n_aligned = 2048
+
+    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
+
+    if layout == "TN":
+        # TN GEMM: M=m_splits[i], N=A.rows, K=A.cols
+        if pad_dim == "K":
+            k_val = unaligned_k
+            m_vals = [m_aligned] * z
+            n_val = n_aligned
+        elif pad_dim == "M":
+            k_val = k_aligned
+            m_vals = unaligned_m
+            n_val = n_aligned
+        else:  # N
+            k_val = k_aligned
+            m_vals = [m_aligned] * z
+            n_val = unaligned_n
+
+        A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
+        B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
+        total_m = sum(m_vals)
+        out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
+        out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
+        m_splits = m_vals
+        grad = False
+        single_output = True
+    else:  # NN
+        # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows
+        if pad_dim == "K":
+            gemm_k = unaligned_k
+            m_vals = [m_aligned] * z
+            n_out = n_aligned
+        elif pad_dim == "M":
+            gemm_k = k_aligned
+            m_vals = unaligned_m
+            n_out = n_aligned
+        else:  # N
+            gemm_k = k_aligned
+            m_vals = [m_aligned] * z
+            n_out = unaligned_n
+
+        A = [torch.randn(gemm_k, n_out, dtype=dtype, device="cuda") for _ in range(z)]
+        B = [torch.randn(m, gemm_k, dtype=dtype, device="cuda") for m in m_vals]
+        total_m = sum(m_vals)
+        out = [torch.randn(total_m, n_out, dtype=dtype, device="cuda")]
+        out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
+        m_splits = m_vals
+        grad = True
+        single_output = True
+
+    # Reference: individual GEMMs
+    for i in range(z):
+        general_gemm(
+            A[i],
+            B[i],
+            dtype,
+            grad=grad,
+            accumulate=accumulate,
+            layout=layout,
+            out=out_ref[i],
+        )
+    if single_output:
+        out_ref = [torch.cat(out_ref)]
+
+    general_grouped_gemm(
+        A,
+        B,
+        out,
+        [None] * z,
+        dtype,
+        m_splits=m_splits,
+        grad=grad,
+        accumulate=accumulate,
+        layout=layout,
+        single_output=single_output,
+    )
+
+    for o, o_ref in zip(out, out_ref):
+        if IS_HIP_EXTENSION and accumulate and dtype == torch.bfloat16 and get_device_compute_capability() == (9, 4):
+            torch.testing.assert_close(o, o_ref, rtol=4e-2, atol=4e-2)
+        else:
+            torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
+
+    os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
+
 @pytest.mark.parametrize("N", [32])
 @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize(
diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h
index c89f10232..75746ab8f 100644
--- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h
+++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h
@@ -140,7 +140,17 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,
 
   if (!Kernel::IsSupportedArgument(kargs)) {
     NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
-              "Falling back.");
+              "transA=", ctx.transA, " transB=", ctx.transB,
+              " accumulate=", ctx.accumulate, " groups=", ctx.group_num,
+              ". Falling back. "
" + "CK_Tile constraints for bf16/fp16: " + "contiguous dim of A and B must be dword-aligned (even), " + "N must be multiple of 16 (GetVectorSizeC)."); + for (size_t i = 0; i < descs.size(); ++i) { + NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K, + " stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B, + " stride_E=", descs[i].stride_E); + } return false; } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 660dbefb8..1f66cdf57 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -41,8 +41,11 @@ struct TileCfg_256x128x64 : TileCfg_256x256x64 { static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; +template +struct WithPadding : Base { + static constexpr bool kPadM = PadM_; + static constexpr bool kPadN = PadN_; + static constexpr bool kPadK = PadK_; }; template , \ accum_option>; \ runner = std::make_unique(); \ }) @@ -216,6 +219,37 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, const ck_tile::stream_config s{ctx.stream}; std::unique_ptr runner = nullptr; + // Check M and K alignment across all groups. + // All tile configs share the same M_Tile (256) and K_Tile (64). + constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile; + constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile; + + bool need_m_pad = false; + bool need_k_pad = false; + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + int64_t Ad0 = 0, Ad1 = 0; + if (get_flat_2d_dims(*A_te, Ad0, Ad1)) { + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + + if (M % M_Tile != 0) + need_m_pad = true; + if (K % K_Tile != 0) + need_k_pad = true; + if (need_m_pad && need_k_pad) + break; + } + } + + // CK tile kernel produces incorrect results with kPadK + ColMajor B. + // Fall back to cuBLAS for this combination. + if (need_k_pad && ctx.transB) { + return false; + } + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { using ALayout = std::conditional_t; @@ -230,13 +264,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); - } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); - } + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); + } else { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); + } + }); + }); }); }); });