Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3078,6 +3078,128 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.skipif(
    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
    reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN"])
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize(
    "pad_dim",
    # "MK" exercises the combined kPadM && kPadK runner instantiation,
    # which the dispatch logic can select when both dims are unaligned.
    ["K", "M", "N", "MK"],
    ids=lambda d: f"pad{d}",
)
def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim):
    """Test CK grouped GEMM with M, N, K (or M and K) not aligned to CK tile size.

    CK constraints for bf16/fp16:
    - Contiguous dim of A/B must be dword-aligned (even for 2-byte types).
      RowMajor: contiguous dim is cols (K for A, N for B).
      ColMajor: contiguous dim is rows (M for A, K for B).
    - N: must be multiple of 16 (GetVectorSizeC, no dword fallback), tile 128/256
    - K tile: 64, M tile: 256
    """
    torch.manual_seed(0)
    z = 8  # number of groups

    # Unaligned values per dimension (all satisfy CK vector-load constraints).
    # K: even but not multiple of tile (64). Same for all groups.
    # M: not multiples of tile (256), varies per group.
    # N: multiple of 16 but not multiple of tile (128).
    unaligned_k = 2026
    unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
    unaligned_n = 2032

    # Aligned defaults.
    k_aligned = 2048
    m_aligned = 256
    n_aligned = 2048

    # Every dimension named in pad_dim is unaligned; the rest stay aligned.
    # pad_dim "MK" therefore makes both M and K unaligned at once.
    k_val = unaligned_k if "K" in pad_dim else k_aligned
    m_vals = unaligned_m if "M" in pad_dim else [m_aligned] * z
    n_val = unaligned_n if "N" in pad_dim else n_aligned

    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    try:
        if layout == "TN":
            # TN GEMM: M=m_splits[i], N=A.rows, K=A.cols
            A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
            grad = False
        else:  # NN
            # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows
            A = [torch.randn(k_val, n_val, dtype=dtype, device="cuda") for _ in range(z)]
            grad = True
        B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
        total_m = sum(m_vals)
        # Single packed output; reference results are written into row-slices of it.
        out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
        out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
        m_splits = m_vals
        single_output = True

        # Reference: individual GEMMs, one per group.
        for i in range(z):
            general_gemm(
                A[i],
                B[i],
                dtype,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                out=out_ref[i],
            )
        if single_output:
            out_ref = [torch.cat(out_ref)]

        general_grouped_gemm(
            A,
            B,
            out,
            [None] * z,
            dtype,
            m_splits=m_splits,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            single_output=single_output,
        )

        for o, o_ref in zip(out, out_ref):
            # gfx942 bf16 with accumulation needs looser tolerances.
            if (
                IS_HIP_EXTENSION
                and accumulate
                and dtype == torch.bfloat16
                and get_device_compute_capability() == (9, 4)
            ):
                torch.testing.assert_close(o, o_ref, rtol=4e-2, atol=4e-2)
            else:
                torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)
    finally:
        # Always restore the env so a failure doesn't leak into later tests.
        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)

@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,17 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,

if (!Kernel::IsSupportedArgument(kargs)) {
NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
"Falling back.");
"transA=", ctx.transA, " transB=", ctx.transB,
" accumulate=", ctx.accumulate, " groups=", ctx.group_num,
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even), "
"N must be multiple of 16 (GetVectorSizeC).");
Copy link
Copy Markdown
Contributor

@aris134 aris134 May 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this function (GetVectorSizeC) implies that N must be a multiple of 16. For the M16N16K32 warp-gemm path, the relevant attributes appear to be kN = 16 and kCNLane = 16, so in the row-major non-TransposeC case this returns kCNLane / kN = 1. That seems to describe the per-thread contiguous C vector size rather than an N divisibility requirement. Is there another place where the N % 16 == 0 constraint is enforced or assumed?

for (size_t i = 0; i < descs.size(); ++i) {
NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K,
" stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B,
" stride_E=", descs[i].stride_E);
}
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ struct TileCfg_256x128x64 : TileCfg_256x256x64 {
static constexpr ck_tile::index_t N_Tile = 128;
};

struct TileCfg_256x128x64_padding : TileCfg_256x128x64 {
static constexpr bool kPadN = true;
// Mixin that derives from an existing tile config and overrides its padding
// flags. PadM_/PadN_/PadK_ select kernel-side padding for the corresponding
// GEMM dimension; everything else (tile sizes, etc.) is inherited from Base.
template <typename Base, bool PadM_, bool PadN_, bool PadK_>
struct WithPadding : Base {
  static constexpr bool kPadM = PadM_;
  static constexpr bool kPadN = PadN_;
  static constexpr bool kPadK = PadK_;
};

template <typename AType,
Expand Down Expand Up @@ -196,15 +199,15 @@ class GroupedGemmRunner : public RunnerInterface {
}
};

// Instantiate a GroupedGemmRunner for BaseCfg_ wrapped with the requested
// M/N/K padding flags, switching on ctx.accumulate at compile time.
// Expects AType/BType/CType, ALayout/BLayout/CLayout, ctx and runner to be
// in scope at the expansion site.
#define MAKE_RUNNER(BaseCfg_, kPadM_, kPadN_, kPadK_)                           \
  TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, {           \
    using Runner = GroupedGemmRunner<AType,                                     \
                                     BType,                                     \
                                     CType,                                     \
                                     ALayout,                                   \
                                     BLayout,                                   \
                                     CLayout,                                   \
                                     WithPadding<BaseCfg_, kPadM_, kPadN_, kPadK_>, \
                                     accum_option>;                             \
    runner = std::make_unique<Runner>();                                        \
  })
Expand All @@ -216,6 +219,37 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
const ck_tile::stream_config s{ctx.stream};
std::unique_ptr<RunnerInterface> runner = nullptr;

// Check M and K alignment across all groups.
// All tile configs share the same M_Tile (256) and K_Tile (64).
constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile;
constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile;

bool need_m_pad = false;
bool need_k_pad = false;

for (int i = 0; i < ctx.group_num; ++i) {
const transformer_engine::Tensor* A_te =
transformer_engine::convertNVTETensorCheck(ctx.A[i]);
int64_t Ad0 = 0, Ad1 = 0;
if (get_flat_2d_dims(*A_te, Ad0, Ad1)) {
const int64_t M = ctx.transA ? Ad1 : Ad0;
const int64_t K = ctx.transA ? Ad0 : Ad1;

if (M % M_Tile != 0)
need_m_pad = true;
if (K % K_Tile != 0)
need_k_pad = true;
if (need_m_pad && need_k_pad)
break;
}
}

// CK tile kernel produces incorrect results with kPadK + ColMajor B.
// Fall back to cuBLAS for this combination.
if (need_k_pad && ctx.transB) {
return false;
}
Comment on lines +249 to +251
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming we have access to B column-wise storage, could we avoid this fallback by selecting B’s column-wise buffer and calling CK with transB=false, while preserving the same logical GEMM? In other words, is the incorrect-result issue specific to CK’s kPadK + transB=true / ColMajor-B path, or would the columnwise-buffer normalization still hit the same underlying issue?


TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, {
using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

Expand All @@ -230,13 +264,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, {
TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, {
if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK);
} else {
MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK);
}
});
});
});
});
});
Expand Down