HipKittens MXFP8 GEMM Support#566
Conversation
```python
)
if use_bias:
    pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
```
```diff
+ size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
  RefD.to_cpu();
- compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
+ compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol, true, mismatch_limit);
```
The output is not MXFP8. Changing the test seeds caused a few small mismatches due to the extra noise from MXFP8 quantization. This occurred for hipBLASLt, not hipKittens (although I assume there are seeds where it's the other way around).
I see. How about keeping the seeds for the original hipblaslt flow, or relaxing their tolerance?
Previously this mismatch_limit relaxation was used to allow a 2X scale difference, if I recall correctly.
Why didn't we hit this in CI before?
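For context, the relaxed comparison above allows roughly one mismatched element per million output elements, with a floor of one. A small Python sketch of the same formula (function name is illustrative, not from the codebase):

```python
def mismatch_limit(m: int, n: int, use_mxfp8: bool) -> int:
    """Allowed number of mismatched elements when comparing an (m, n)
    GEMM output against the reference. Mirrors the C++ expression
    `use_mxfp8 ? max((size_t)1, m * n / 1'000'000) : 0`."""
    if not use_mxfp8:
        return 0
    return max(1, (m * n) // 1_000_000)

print(mismatch_limit(256, 256, True))    # small outputs hit the floor of 1
print(mismatch_limit(8192, 8192, True))  # 67108864 // 1_000_000 = 67
```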
```cpp
launch_pack_scales((const uint8_t *)scale_B, packed_sb, N, scale_K, k_iters, stream);

GemmEpilogue ep = select_epilogue(bias, aux_gelu);
dispatch_tn_gemm(ep, a_fp8_code, b_fp8_code,
```
Hmm, if you convert the other layouts (NN, NT) to TN and then launch as TN, why not request TN directly from TE upstream during the cast-transpose and modify the canonicalize-GEMM function?
Also, I thought that for MXFP8, transposing the quantized data and scales does not give us a rowwise -> columnwise conversion, right?
I think your first question is answered by your second comment. We can't swap rowwise in for columnwise, so we can't request TN directly from TE upstream.
To your second question: we're not swapping rowwise for columnwise here, or the other way around. We are transposing the relevant columnwise or rowwise data and scales after quantization.
Then is it possible to change the hipKittens MXFP8 kernels to accept a non-TN layout and consume columnwise data from TE directly? Then we should be able to remove the dynamic workspace entirely.
BTW, do we or NV upstream currently require a non-TN layout for MXFP8?
```python
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
if ws is None:
    ws = torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
    _workspace_cache[key] = ws
return ws


def check_mxfp8_workspace(device: int, needed: int) -> None:
    """Grow the workspace to required size"""
    key = (device, False, False)
    ws = _workspace_cache.get(key)
    if ws is not None and ws.shape[0] >= needed:
        return
    _workspace_cache[key] = torch.empty(needed, dtype=torch.uint8, device=device)
```
I have concerns about the proposed workspace cache system:
1) In non-MoE runs, it will try to allocate the largest size kitten_gemm needs, replacing previously allocated smaller buffers and relying on PyTorch garbage collection to deallocate them. The biggest single buffer will then stay in the process from the second iteration onward.
2) For MoE runs, sizes are dynamic, so the cache can presumably still change after the warm-up runs.
If we can force TE upstream to always provide a TN layout, can we remove this dynamic workspace entirely?
I understand your concern, but I think we are OK for current models.
1) This is correct: we only keep the largest workspace, relying on PyTorch GC to delete the old one. This only affects iteration 1.
2) Since the workspace is shared by all GEMMs in the model, I think this is unlikely. For example, with DeepSeek 671B at BS=2, the largest non-MoE workspace needed is for the dense layers' FFN, where the wgrad GEMM needs 200 MB compared to the theoretical maximum MoE GEMM size of 72 MB, so this wouldn't occur. For a full MoE model like Qwen 235B, we still don't run into this issue, as the largest non-MoE GEMM would use 96 MB vs. a 44 MB worst case for MoE.
It is possible that a model exists, or could exist, where the MoE GEMM is the largest, but convergence theory would imply that we hit the maximum allocation threshold fairly quickly with a many-layer model, and it almost certainly wouldn't affect the performance of a full training run.
Hmm, in addition to my other comment on the possibility of removing this dynamic workspace entirely, if we really do need a dynamic buffer:
1) Let's try allocating the buffer without a cache, to see whether it really hurts e2e training, before working on this delicate buffer cache.
2) Convergence theory usually works in theory papers with distributional assumptions on the inputs. I agree that it works fine for our Qwen or DeepSeek runs, but our library may run into strange corner cases when used by customers.
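To make the grow-only policy under discussion concrete, here is a minimal, torch-free sketch. The `bytearray` stands in for the device tensor, and `check_workspace` is an illustrative name, not the actual TE function:

```python
_workspace_cache = {}

def check_workspace(device: int, needed: int) -> bytearray:
    """Grow-only policy: keep at most one buffer per key, replacing it
    (and relying on GC to free the old one) only when a larger size is
    requested. Mirrors the cache logic discussed above."""
    key = (device, False, False)
    ws = _workspace_cache.get(key)
    if ws is None or len(ws) < needed:
        ws = bytearray(needed)  # stand-in for torch.empty(needed, ...)
        _workspace_cache[key] = ws
    return ws

a = check_workspace(0, 100)  # allocates a 100-byte buffer
b = check_workspace(0, 50)   # smaller request reuses the 100-byte buffer
c = check_workspace(0, 200)  # larger request replaces it; old one is garbage
print(len(b), len(c))  # 100 200
```

This makes the reviewer's first point visible: after warm-up, only the single largest buffer per key survives, and the old buffers are freed only by garbage collection.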
```cpp
kittens::zero(cA); kittens::zero(cB); kittens::zero(cC); kittens::zero(cD);

const int NUM_XCDS = 8;
const int WGM = 8;
```
Just wondering, was WGM = 8 tuned for this kernel?
```cpp
if constexpr (HAS_BIAS) {
    int m_base_lo = block_m + warp_m * REG_M;
    int m_base_hi = block_m + (WARPS_ROW + warp_m) * REG_M;
    int lane = kittens::laneid();
    int row_off = cA.base_tile_stride * (lane / cA.base_tile_cols);

    #pragma unroll
    for (int i = 0; i < cA.height; i++) {
        #pragma unroll
        for (int j = 0; j < cA.width; j++) {
            #pragma unroll
            for (int k = 0; k < cA.base_tile_num_strides; k++) {
                #pragma unroll
                for (int l = 0; l < cA.base_tile_stride / 2; l++) {
                    int idx = l + k * cA.base_tile_stride / 2;
                    int m_lo_x = m_base_lo + i * 16 + row_off + l * 2;
                    int m_lo_y = m_lo_x + 1;
                    int m_hi_x = m_base_hi + i * 16 + row_off + l * 2;
                    int m_hi_y = m_hi_x + 1;
                    float b_lo_x = read_bias(bias, bias_dtype, m_lo_x);
                    float b_lo_y = read_bias(bias, bias_dtype, m_lo_y);
                    float b_hi_x = read_bias(bias, bias_dtype, m_hi_x);
                    float b_hi_y = read_bias(bias, bias_dtype, m_hi_y);
                    cA.tiles[i][j].data[idx].x += b_lo_x;
                    cA.tiles[i][j].data[idx].y += b_lo_y;
                    cB.tiles[i][j].data[idx].x += b_lo_x;
                    cB.tiles[i][j].data[idx].y += b_lo_y;
                    cC.tiles[i][j].data[idx].x += b_hi_x;
                    cC.tiles[i][j].data[idx].y += b_hi_y;
                    cD.tiles[i][j].data[idx].x += b_hi_x;
                    cD.tiles[i][j].data[idx].y += b_hi_y;
                }
            }
        }
    }
}
```
Minor comment: would it make sense to factor the bias-add block into an inline helper? It might make the fused epilogue logic easier to scan.
```cpp
} else if (!transa && !transb) {
    return mxfp8_gemm_nn(A, B, C, scale_A, scale_B, M, N, K,
                         a_fp8, b_fp8, bias, bias_dc,
                         aux_gelu, out_dc, aux_dc,
                         workspace, workspace_size, stream);
} else if (!transa && transb) {
    return mxfp8_gemm_nt(A, B, C, scale_A, scale_B, M, N, K,
                         a_fp8, b_fp8, bias, bias_dc,
                         aux_gelu, out_dc, aux_dc,
                         workspace, workspace_size, stream);
}
```
Was the transpose-to-TN strategy benchmarked against native NN/NT handling, assuming native NN/NT is available? I'm wondering about the overhead from the data and scale transposes for realistic problem sizes.
Also, do we expect to ever need support for the TT case?
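For reference, the layout canonicalization being discussed can be sketched with NumPy, ignoring the MXFP8 block scales (which would also have to be transposed and repacked, and are the main source of the overhead asked about). `gemm_tn` is a hypothetical stand-in for a kernel that only supports the TN layout:

```python
import numpy as np

def gemm_tn(a_t, b):
    """Stand-in for a TN-only GEMM kernel: computes op(A) @ op(B) with
    transa=T, transb=N, i.e. a_t.T @ b."""
    return a_t.T @ b

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8))   # M x K
B = rng.standard_normal((8, 5))   # K x N

# NN case: C = A @ B -> materialize A^T and call the TN kernel.
C_nn = gemm_tn(np.ascontiguousarray(A.T), B)
assert np.allclose(C_nn, A @ B)

# NT case: C = A @ B2^T with B2 of shape N x K -> transpose both operands.
B2 = rng.standard_normal((5, 8))  # N x K
C_nt = gemm_tn(np.ascontiguousarray(A.T), np.ascontiguousarray(B2.T))
assert np.allclose(C_nt, A @ B2.T)
```

The `ascontiguousarray` calls model the extra materialized transposes; for MXFP8, the corresponding scale tensors would need the same treatment.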
```cpp
            aux_gelu, M, N, K, out_dtype, aux_dtype, stream);
        break;
    }
}
```
Nit: could we use a small local macro here to reduce the repeated argument list? Perhaps something like:

```cpp
#define DISPATCH_EPILOGUE_CASE(EPI) \
    case EPI: { \
        dispatch_fp8_types<EPI>(a_fp8, b_fp8, A, B, C, packed_sa, packed_sb, \
                                bias, bias_dtype, aux_gelu, M, N, K, \
                                out_dtype, aux_dtype, stream); \
        break; \
    }
```

and then:

```cpp
switch (epilogue) {
    DISPATCH_EPILOGUE_CASE(GemmEpilogue::DEFAULT)
    DISPATCH_EPILOGUE_CASE(GemmEpilogue::BIAS)
    DISPATCH_EPILOGUE_CASE(GemmEpilogue::GELU_AUX)
    DISPATCH_EPILOGUE_CASE(GemmEpilogue::GELU_AUX_BIAS)
}
#undef DISPATCH_EPILOGUE_CASE
```
```cpp
if (!use_mxfp8 && params.force_hipblaslt) {
    GTEST_SKIP() << "force_hipblaslt only relevant for MXFP8";
}
if (use_mxfp8) {
```
Add a new const bool use_hipblaslt_fp8 = (!use_mxfp8 || params.force_hipblaslt); this combination is used for many of the skips below. And all of this should go below, under the HIP_PLATFORM_AMD ifdef, under has_fp8.
```cpp
    workspace_size = 67108864;
}
if (use_mxfp8) {
    workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
```
Skip it if force_hipblaslt?
```diff
  [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
-     return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
+     return MKN(std::get<0>(info.param)) + "x" +
+            std::to_string(std::get<1>(info.param)) + "x" +
```
What is the point? They are only ever set to false.
```cpp
    GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias || params.use_gelu) {
    if (params.force_hipblaslt) {
```
It is skipped below anyway; if you are adding it for the future, move it after the more generic one.
```cpp
#include <hip/hip_runtime.h>
#include <cstddef>

enum KittensDType {
```
Is this copied from some hipKittens enum? If so, add a comment saying so.
To avoid confusion with .cu hipification results, please use .cu or .cpp
```cmake
if(USE_HIPKITTENS_GEMM)
    list(FIND CMAKE_HIP_ARCHITECTURES "gfx950" _gfx950_index)
    if(_gfx950_index EQUAL -1)
```
Maybe move this into a function in the kittens CMakeLists.txt and call it from here?
```python
return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)
```
Why don't we rely on torch's memory caching?
Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASlt and offers additional epilogues such as BIAS and GELU AUX.
Requires a workspace sized relative to the model, often larger than hipBLASlt's, but with significant performance improvements. Only builds for gfx950, and requires M and N to be divisible by 256.
Adds the hipKittens header library as a submodule.
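The shape restriction above mirrors the eligibility check used in the tests; a minimal sketch (function name is illustrative):

```python
def hipkittens_eligible(m: int, n: int, k: int) -> bool:
    """The HipKittens MXFP8 path only handles M and N divisible by 256
    and K of at least 256; other shapes fall back to hipBLASLt."""
    return (m % 256 == 0) and (n % 256 == 0) and (k >= 256)

print(hipkittens_eligible(512, 768, 256))  # True
print(hipkittens_eligible(512, 100, 256))  # False: N not divisible by 256
```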