Skip to content

HipKittens MXFP8 GEMM Support#566

Open
alextmagro wants to merge 11 commits into
devfrom
hipkittens_mxfp8
Open

HipKittens MXFP8 GEMM Support#566
alextmagro wants to merge 11 commits into
devfrom
hipkittens_mxfp8

Conversation

@alextmagro
Copy link
Copy Markdown
Contributor

@alextmagro alextmagro commented Apr 28, 2026

Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASlt, and offers additional epilogues such as BIAS and GELU AUX

Requires a workspace sized relative to the model. Often larger than hipBLASlt, but with significant performance improvements. Only builds for gfx950, and requires M / 256 and N / 256.

Adds hipKittens header library as a submodule.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/jax/utils.py
)
if use_bias:
pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same hardcoding 256s...

size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
RefD.to_cpu();
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol, true, mismatch_limit);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our output is mxfp8?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output is not mxfp8. Changing the seeds of the tests caused a few small mismatches due to the extra noise from MXFP8. This occurred for hipBLASlt, not hipKittens (although I assume there are seeds where this is the other way around).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. How about keep the seeds for the original hipblaslt flow or relax their tolerance?

Previously this mismatch_limit relaxation was used to allow 2X scale difference, if I recall correctly

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we didn't hit it with CI before?

Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.hip Outdated
launch_pack_scales((const uint8_t *)scale_B, packed_sb, N, scale_K, k_iters, stream);

GemmEpilogue ep = select_epilogue(bias, aux_gelu);
dispatch_tn_gemm(ep, a_fp8_code, b_fp8_code,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Emm, if you convert other NN, NT to TN, then launch as tn layout, why not request TN directly from TE upstream during the cast transpose and modify the canonicalize gemm function?

Also I thought for mxfp8, transpose quantized data and scales does not give us rowwise -> columnwise conversion, right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your first question is answered by your second comment. We can't swap in rowwise for colwise, so we can't request TN directly from TE upstream.

To your second question, we're not swapping in rowwise for colwise here, or the other way around. We are transposing the relevant colwise or rowwise data and scales after quantization.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then is it possible to change hipkitten mxfp8 kernels to ask non-TN layout to consume columnwise data from TE directly? Then we should be able to remove the dynamic workspace entirely?

BTW, do we or NV upstream currently requires non-TN layout in mxfp8?

Comment on lines +107 to +121
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
if ws is None:
ws = torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
_workspace_cache[key] = ws
return ws


def check_mxfp8_workspace(device: int, needed: int) -> None:
"""Grow the workspace to required size"""
key = (device, False, False)
ws = _workspace_cache.get(key)
if ws is not None and ws.shape[0] >= needed:
return
_workspace_cache[key] = torch.empty(needed, dtype=torch.uint8, device=device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have concerns for the proposed workspace cache system:
1). In non-moe runs, it will try to allocate the largest size kitten_gemm needs, replace previous allocated smaller buffers, relying on pytorch garbage collection to deallocate. Then the biggest single buffer will stay in the process starting from the second iteration.
2). For the MOE run, sizes are dynamic, so probably the cache system can still change after the warm up runs

If we can force TE upstream to always provide you TN layout, then we can remove this dynamic workspace entirely?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your concern, but think we are ok for current models.

1.) This is correct, we only keep the largest workspace, relying on pytorch GC to delete the old workspace. This only affects iteration 1.
2.) Since the workspace is shared for all GEMMs in the model, I think this is unlikely. For example, with DeepSeek 671B with BS=2, the largest non-MoE workspace needed is for the dense layers FFN, where wgrad GEMM will need 200 MB compared to the theoretically maximum MoE GEMM size of 72 MB so this wouldn't occur. For a full MoE Model like Qwen 235B, we still don't run into this issue as the largest non-MoE GEMM would use 96 MB vs 44 MB worst case for MoE.

It is possible that there is a model that exists or could exists where the MoE GEMM is the largest, but convergence theory would imply that we hit the maximum allocation threshold fairly quickly with a many-layer model, and it almost certainly wouldn't affect the performance of a full training run.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Emm, in addition to my another comment on the possibility to remove this dynamic workspace directly, if we really need this dynamic buffer:
1). let's try to allocate buffer without cache to see if it really hurts the e2e training before working on this delicate buffer cache?
2). Convergence theory usually works in theory papers with input distributional assumptions. I agree for our qwen or ds it works fine. Our library may run into strange corner cases when used by customers.

Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py Outdated
@alextmagro alextmagro requested a review from wangye805 May 5, 2026 20:26
kittens::zero(cA); kittens::zero(cB); kittens::zero(cC); kittens::zero(cD);

const int NUM_XCDS = 8;
const int WGM = 8;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering was WGM = 8 tuned for this kernel?

Comment on lines +306 to +341
if constexpr (HAS_BIAS) {
int m_base_lo = block_m + warp_m * REG_M;
int m_base_hi = block_m + (WARPS_ROW + warp_m) * REG_M;
int lane = kittens::laneid();
int row_off = cA.base_tile_stride * (lane / cA.base_tile_cols);

#pragma unroll
for (int i = 0; i < cA.height; i++) {
#pragma unroll
for (int j = 0; j < cA.width; j++) {
#pragma unroll
for (int k = 0; k < cA.base_tile_num_strides; k++) {
#pragma unroll
for (int l = 0; l < cA.base_tile_stride / 2; l++) {
int idx = l + k * cA.base_tile_stride / 2;
int m_lo_x = m_base_lo + i * 16 + row_off + l * 2;
int m_lo_y = m_lo_x + 1;
int m_hi_x = m_base_hi + i * 16 + row_off + l * 2;
int m_hi_y = m_hi_x + 1;
float b_lo_x = read_bias(bias, bias_dtype, m_lo_x);
float b_lo_y = read_bias(bias, bias_dtype, m_lo_y);
float b_hi_x = read_bias(bias, bias_dtype, m_hi_x);
float b_hi_y = read_bias(bias, bias_dtype, m_hi_y);
cA.tiles[i][j].data[idx].x += b_lo_x;
cA.tiles[i][j].data[idx].y += b_lo_y;
cB.tiles[i][j].data[idx].x += b_lo_x;
cB.tiles[i][j].data[idx].y += b_lo_y;
cC.tiles[i][j].data[idx].x += b_hi_x;
cC.tiles[i][j].data[idx].y += b_hi_y;
cD.tiles[i][j].data[idx].x += b_hi_x;
cD.tiles[i][j].data[idx].y += b_hi_y;
}
}
}
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comment: would it make sense to factor the bias-add block into an inline helper? It might make the fused epilogue logic easier to scan.

Comment on lines +811 to +821
} else if (!transa && !transb) {
return mxfp8_gemm_nn(A, B, C, scale_A, scale_B, M, N, K,
a_fp8, b_fp8, bias, bias_dc,
aux_gelu, out_dc, aux_dc,
workspace, workspace_size, stream);
} else if (!transa && transb) {
return mxfp8_gemm_nt(A, B, C, scale_A, scale_B, M, N, K,
a_fp8, b_fp8, bias, bias_dc,
aux_gelu, out_dc, aux_dc,
workspace, workspace_size, stream);
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was the transpose-to-TN strategy benchmarked against native NN/NT handling, assuming native NN/NT is available? I’m wondering about the overhead from the data and scale transposes for realistic problem sizes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, do we expect to ever need support for the TT case?

aux_gelu, M, N, K, out_dtype, aux_dtype, stream);
break;
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: could we use a small local macro here to reduce the repeated argument list? Perhaps something like:

#define DISPATCH_EPILOGUE_CASE(EPI)                                      \
  case EPI: {                                                            \
    dispatch_fp8_types<EPI>(a_fp8, b_fp8, A, B, C, packed_sa, packed_sb,  \
                            bias, bias_dtype, aux_gelu, M, N, K,         \
                            out_dtype, aux_dtype, stream);               \
    break;                                                               \
  }

and then:

switch (epilogue) {
  DISPATCH_EPILOGUE_CASE(GemmEpilogue::DEFAULT)
  DISPATCH_EPILOGUE_CASE(GemmEpilogue::BIAS)
  DISPATCH_EPILOGUE_CASE(GemmEpilogue::GELU_AUX)
  DISPATCH_EPILOGUE_CASE(GemmEpilogue::GELU_AUX_BIAS)
}
#undef DISPATCH_EPILOGUE_CASE

if (!use_mxfp8 && params.force_hipblaslt) {
GTEST_SKIP() << "force_hipblaslt only relevant for MXFP8";
}
if (use_mxfp8) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add new const bool use_hipblaslt_fp8 = (!use_mxfp8 || param.force_hipblaslt) - this combination is used below for many skips. And all this should be below, under ifdef HIP_PLATFORM_AMD under has_fp8

workspace_size = 67108864;
}
if (use_mxfp8) {
workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skip it if force_hipblaslt?

size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
RefD.to_cpu();
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol, true, mismatch_limit);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we didn't hit it with CI before?

[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a point, they are set to false only

GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias || params.use_gelu) {
if (params.force_hipblaslt) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is skipped below anyway, if add it for future, move it after more generic one

#include <hip/hip_runtime.h>
#include <cstddef>

enum KittensDType {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it copied from some hipKittent enum? Put comment then

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid confusion with .cu hipification results, please use .cu or .cpp


if(USE_HIPKITTENS_GEMM)
list(FIND CMAKE_HIP_ARCHITECTURES "gfx950" _gfx950_index)
if(_gfx950_index EQUAL -1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be move it to a function in kittens CMakeList.txt and call it from here?


return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we don't rely on torch memory caching?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 1 CI test level 1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants