HipKittens MXFP8 GEMM Support #566
Open: alextmagro wants to merge 11 commits into dev from hipkittens_mxfp8
Commits (11):
f9d5ce2 HipKittens MXFP8 GEMM Support (alextmagro)
aac5860 Update HipKittens branch after upstream MXFP8 merge (alextmagro)
c917ed0 Merge remote-tracking branch 'origin/dev' into hipkittens_mxfp8 (alextmagro)
3a91321 Update HipKittens commit and address PR comments (alextmagro)
cc719fe Merge remote-tracking branch 'origin/dev' into hipkittens_mxfp8 with (alextmagro)
fcda154 Resolve conflicts, ensure fp4 workspace changes are harmonious (alextmagro)
70fba6d min workspace size guaranteed (alextmagro)
455002e add hipkittens to wheels (alextmagro)
ba60ef5 fix issue with gfx942 for unified build (alextmagro)
f72b7b8 Cleanup and workspace changes (alextmagro)
731640a Merge remote-tracking branch 'origin/dev' into hipkittens_mxfp8 (alextmagro)
Submodule hipkittens added at 778274
@@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

    std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
        {32, 128, 16},
        {256, 256, 256},
        {768, 3072, 4096},
        {4096, 16384, 4096},
    };

    // A, B, Bias, Gelu, D
@@ -168,6 +170,20 @@ __global__ void compute_ref_kernel(

    }

    constexpr size_t kMXFP8GroupSize = 32;
    constexpr size_t kKTileSize = 128;

    static size_t compute_mxfp8_workspace_size(size_t m, size_t k, size_t n, bool transa, bool transb, size_t base_size) {
        size_t k_iters = k / kKTileSize;
        size_t scale_k = k / kMXFP8GroupSize;
        size_t sa_pk = round_up_to_nearest_multiple(k_iters * m * 4, 256);
        size_t sb_pk = k_iters * n * 4;
        size_t needed = round_up_to_nearest_multiple(sa_pk, 256) + sb_pk;
        if (!transa) needed += round_up_to_nearest_multiple(m * k, 256) + round_up_to_nearest_multiple(m * scale_k, 256) + round_up_to_nearest_multiple(sa_pk, 256);
        if (transb) needed += round_up_to_nearest_multiple(n * k, 256) + round_up_to_nearest_multiple(n * scale_k, 256) + round_up_to_nearest_multiple(sb_pk, 256);
        return std::max(base_size, needed);
    }

    struct TestParams {
        size_t m;
        size_t k;
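For a sense of scale, here is a minimal standalone sketch of the sizing formula above, evaluated for the largest MXFP8 test case; `round_up` is a stand-in for the repo's `round_up_to_nearest_multiple`, assumed to round up to the next multiple of its second argument.

```cpp
// Standalone sketch of the workspace formula introduced above; round_up is a
// stand-in for round_up_to_nearest_multiple from the test file.
#include <algorithm>
#include <cstddef>
#include <cstdio>

static size_t round_up(size_t x, size_t mult) { return (x + mult - 1) / mult * mult; }

static size_t mxfp8_workspace(size_t m, size_t k, size_t n, bool transa, bool transb,
                              size_t base_size) {
    size_t k_iters = k / 128;   // kKTileSize
    size_t scale_k = k / 32;    // kMXFP8GroupSize
    size_t sa_pk = round_up(k_iters * m * 4, 256);
    size_t sb_pk = k_iters * n * 4;
    size_t needed = round_up(sa_pk, 256) + sb_pk;
    if (!transa) needed += round_up(m * k, 256) + round_up(m * scale_k, 256) + round_up(sa_pk, 256);
    if (transb)  needed += round_up(n * k, 256) + round_up(n * scale_k, 256) + round_up(sb_pk, 256);
    return std::max(base_size, needed);
}

int main() {
    // Largest MXFP8 test case {4096, 16384, 4096} with transa = false, transb = true:
    // prints 146800640 (~140 MiB), more than the 64 MiB default used elsewhere in the
    // test, which is why the helper takes the max with base_size.
    std::printf("%zu\n", mxfp8_workspace(4096, 16384, 4096, false, true, 67108864));
    return 0;
}
```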
@@ -177,6 +193,7 @@ struct TestParams {

        bool transa;
        bool transb;
        NVTEScalingMode scaling_mode;
        bool force_hipblaslt;
    };
@@ -341,8 +358,7 @@ void performTest(const TestParams& params) {

    const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype);
    const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING;

    if (use_mxfp8)
    {
    if (use_mxfp8) {
        if (!has_fp8) {
            GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types";
        }
@@ -352,6 +368,9 @@ void performTest(const TestParams& params) {

        if (params.k % 128) {
            GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
        }
        if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k < 256)) {
            GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256";
        }
    }

    cudaDeviceProp prop;

wangye805 marked this conversation as resolved.
@@ -387,22 +406,14 @@ void performTest(const TestParams& params) {

    if (!fp8_supported) {
        GTEST_SKIP() << "FP8 is not supported in current config";
    }

    if (use_mxfp8)
    {
        bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
        if (!mxfp8_supported) {
            GTEST_SKIP() << "MXFP8 is not supported in current config";
        }
        if (isFp8Type(dtype)){
            GTEST_SKIP() << "MXFP8 with float8 output is not supported";
        }
        if (params.use_bias) {
            GTEST_SKIP() << "MXFP8 GEMM with bias is not supported";
        }
    bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
    if (use_mxfp8 && !mxfp8_supported) {
        GTEST_SKIP() << "MXFP8 is not supported in current config";
    }

    if (params.use_gelu && !fp8_gelu_fusion_config) {
    if (use_mxfp8 && params.use_bias && params.force_hipblaslt) {
        GTEST_SKIP() << "MXFP8 GEMM with bias is not supported by hipBLASLt";
    }
    if (params.use_gelu && !fp8_gelu_fusion_config && (params.force_hipblaslt || !use_mxfp8)) {
        GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config";
    }
    if (params.use_bias && dtype == DType::kFloat16) {
@@ -412,29 +423,27 @@ void performTest(const TestParams& params) {

    if (prop.major == 9 && prop.minor == 5) //gfx950 specific hipblasLt limitations
    {
        if (isFp8Type(dtype)){
        if (isFp8Type(dtype)) {
            GTEST_SKIP() << "GEMM with float8 output is not supported";
        }
        if (params.use_gelu && dtype == DType::kBFloat16) {
        if (params.use_gelu && dtype == DType::kBFloat16 && (params.force_hipblaslt || !use_mxfp8)) {
            GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
        }
        if constexpr ((std::is_same<A_Type, bf8>::value || std::is_same<B_Type, bf8>::value) &&
                      std::is_same<D_Type, fp32>::value)
        {
            //GEMM with bias and fp32 output is not supported with bf8 A/B
        if constexpr ((std::is_same_v<A_Type, bf8> || std::is_same_v<B_Type, bf8>) &&
                      std::is_same_v<D_Type, fp32>) {
            if (params.use_bias) {
                GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config";
            }
        }
    }
    if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
    else if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
    {
    #if HIP_VERSION < 70100000
        if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
            GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
        }
    #endif
        if constexpr (std::is_same<D_Type, fp8>::value && std::is_same<Bias_Type, bf16>::value) {
        if constexpr (std::is_same_v<D_Type, fp8> && std::is_same_v<Bias_Type, bf16>) {
            if (params.use_bias && !fp8_gelu_fusion_config) {
                GTEST_SKIP() << "GEMM with BF16 bias and FP8 output is not supported in current config";
            }
@@ -493,6 +502,11 @@ void performTest(const TestParams& params) {

    if ((prop.major == 9 && prop.minor == 5) || prop.major >= 12) {
        workspace_size = 67108864;
    }
    if (use_mxfp8 && !params.force_hipblaslt) {
        workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
                                                      params.transa, params.transb,
                                                      workspace_size);
    }
    #endif
    Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

wangye805 marked this conversation as resolved. Collaborator comment: "Skip it if force_hipblaslt?"
@@ -547,11 +561,12 @@ void performTest(const TestParams& params) {

    }

    auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8);
    size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
    RefD.to_cpu();
    compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);

    if(params.use_gelu){
        auto [atol, rtol] = getTestTolerances(gelu_type, false, false);
        auto [atol, rtol] = getTestTolerances(gelu_type, has_fp8, use_mxfp8);
        RefPreGeluOut.to_cpu();
        compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr<Gelu_Type>(), true, atol, rtol);
    }
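As a quick worked example of the mismatch allowance above (my reading of the formula, not text from the PR): MXFP8 results are allowed roughly one mismatched element per million entries of D, never fewer than one.

```cpp
// Worked example of the MXFP8 mismatch allowance: about one element per
// million entries of D, with a floor of one.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    size_t m = 4096, n = 4096;  // D has 16'777'216 entries
    size_t mismatch_limit = std::max((size_t)1, m * n / 1'000'000);
    std::printf("%zu\n", mismatch_limit);  // prints 16
    return 0;
}
```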
@@ -581,6 +596,12 @@ void performDqTest(const TestParams &params) {

    if (!mxfp8_supported) {
        GTEST_SKIP() << "MXFP8 is not supported in current config";
    }
    if (params.use_bias || params.use_gelu) {
        GTEST_SKIP() << "DqGEMMTestSuite does not yet have reference for bias/gelu epilogues";
    }
    if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k % 128 || params.k < 256)) {
        GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256";
    }

    DType ref_type = dtype;
    TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m};
@@ -608,7 +629,9 @@ void performDqTest(const TestParams &params) {

    Tensor bias;
    Tensor pre_gelu_out;

    size_t workspace_size = 67108864;
    size_t workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
                                                          params.transa, params.transb,
                                                          67108864); // 64 MiB required for hipBLASlt
    Tensor Workspace("Workspace", TShape{workspace_size}, DType::kByte);

    //perform FP8 gemm and copy the output results from GPU memory to CPU memory
@@ -638,6 +661,12 @@ void performDqTest(const TestParams &params) {

    #endif // __HIP_PLATFORM_AMD__

    #define MAKE_TEST_PARAMS(P_) \
        bool force_hipblaslt_ = std::get<5>(GetParam()); \
        if (force_hipblaslt_) { \
            setenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8", "1", 1); \
        } else { \
            unsetenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8"); \
        } \
        TestParams P_ = {.m = std::get<0>(std::get<0>(GetParam())), \
                         .k = std::get<1>(std::get<0>(GetParam())), \
                         .n = std::get<2>(std::get<0>(GetParam())), \
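For readers outside the test harness, a minimal sketch of the same toggle; the environment variable name is taken from this diff, while how the library consumes it is assumed here, not shown.

```cpp
// Minimal sketch of the backend toggle used by MAKE_TEST_PARAMS above.
// NVTE_ROCM_USE_HIPBLASLT_MXFP8 comes from this PR; the dispatch behaviour
// inside the library is assumed, not shown here.
#include <cstdlib>

void select_mxfp8_backend(bool force_hipblaslt) {
    if (force_hipblaslt) {
        setenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8", "1", /*overwrite=*/1);
    } else {
        unsetenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8");
    }
}
```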
@@ -646,13 +675,14 @@ void performDqTest(const TestParams &params) {

                         .transa = std::get<3>(GetParam()).first, \
                         .transb = std::get<3>(GetParam()).second, \
                         .scaling_mode = std::get<4>(GetParam()) \
                                             ? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
                                             : NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING}
                                             ? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
                                             : NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING,\
                         .force_hipblaslt = force_hipblaslt_}

    // <m, k, n>, use_bias, use_gelu, Layout, fp8_scalinig
    // <m, k, n>, use_bias, use_gelu, Layout, fp8_scaling, force_hipblaslt
    class GEMMTestSuite
        : public ::testing::TestWithParam<
              std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode>> {};
              std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode, bool>> {};

    #define MAKE_GEMM_TEST(NAME_, A_, B_, BIAS_, GELU_, D_) \
    TEST_P(GEMMTestSuite, NAME_) { \
@@ -713,19 +743,32 @@ static inline auto MKN(const std::tuple<size_t, size_t, size_t>& shape) {

           std::to_string(std::get<2>(shape));
    }

    static std::string GEMMTestName(const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
        return MKN(std::get<0>(info.param)) + "x" +
               std::to_string(std::get<1>(info.param)) + "x" +
               std::to_string(std::get<2>(info.param)) + "x" +
               TN(std::get<3>(info.param)) + "x" +
               (std::get<4>(info.param) ? "M" : "S") + "x" +
               (std::get<5>(info.param) ? "HB" : "HK");
    }

    INSTANTIATE_TEST_SUITE_P(OperatorTest, GEMMTestSuite,
                             ::testing::Combine(::testing::ValuesIn(test_case_sizes),
                                                ::testing::Values(false, true), //use bias
                                                ::testing::Values(false, true), //use_gelu
                                                ::testing::ValuesIn(kLayouts), //transa,transb
                                                ::testing::Values(false, true)), //use mxfp8
                             [](const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
                                 return MKN(std::get<0>(info.param)) + "x" +
                                        std::to_string(std::get<1>(info.param)) + "x" +
                                        std::to_string(std::get<2>(info.param)) + "x" +
                                        TN(std::get<3>(info.param)) + "x" +
                                        (std::get<4>(info.param) ? "M" : "S");
                             });
                                                ::testing::Values(false), //use mxfp8
                                                ::testing::Values(false)), //force hipblaslt
                             GEMMTestName);

    INSTANTIATE_TEST_SUITE_P(OperatorTestMXFP8, GEMMTestSuite,
                             ::testing::Combine(::testing::ValuesIn(test_case_sizes),
                                                ::testing::Values(false, true), //use bias
                                                ::testing::Values(false, true), //use_gelu
                                                ::testing::ValuesIn(kLayouts), //transa,transb
                                                ::testing::Values(true), //use mxfp8
                                                ::testing::Values(false, true)), //force hipblaslt
                             GEMMTestName);
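To make the new naming scheme concrete, here is a hand-derived example of the suffix GEMMTestName would produce for one MXFP8 case ({256, 256, 256}, no bias, GELU on, HipKittens path); the "TN" token is a stand-in for whatever the TN() layout helper returns, which is outside this hunk.

```cpp
// Hand-derived GEMMTestName suffix for one MXFP8 case; "TN" stands in for the
// value of the TN() layout helper, which is not shown in this diff.
#include <iostream>
#include <string>

int main() {
    std::string name = std::string("256x256x256") + "x" +
                       "0" + "x" +   // use_bias = false
                       "1" + "x" +   // use_gelu = true
                       "TN" + "x" +  // layout token from TN()
                       "M" + "x" +   // MXFP8 scaling
                       "HK";         // HipKittens path (force_hipblaslt = false)
    std::cout << name << "\n";  // 256x256x256x0x1xTNxMxHK
    return 0;
}
```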
    #ifdef __HIP_PLATFORM_AMD__
    class DqGEMMTestSuite: public GEMMTestSuite {};

@@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

    INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
                             ::testing::Combine(::testing::ValuesIn(test_case_sizes_mxfp8),
                                                ::testing::Values(false), // bias - unused
                                                ::testing::Values(false), // gelu - unused
                                                ::testing::ValuesIn(kLayouts), //transa,transb
                                                ::testing::Values(true)), //use mxfp8
                                                ::testing::Values(false), // use bias
                                                ::testing::Values(false), // use gelu
                                                ::testing::ValuesIn(kLayouts), // transa,transb
                                                ::testing::Values(true), // use mxfp8
                                                ::testing::Values(false, true)), // force hipblaslt
                             [](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
                                 return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
                                 return MKN(std::get<0>(info.param)) + "x" +
                                        TN(std::get<3>(info.param)) + "x" +
                                        (std::get<5>(info.param) ? "HB" : "HK");
                             });

    TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
@@ -61,8 +61,9 @@ def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False):

            pytest.skip(
                f"Input shape {(m, k)} x {(k, n)} is not supported by hipblaslt MXFP8 GEMM."
            )
        if use_bias:
            pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
        hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
        if use_bias and not hipkittens_eligible:
            pytest.skip("hipblaslt GEMM does not support MXFP8 with bias.")
    else:
        jax_version = version.parse(jax.__version__)
        if jax_version < version.parse("0.8.2"):

Collaborator comment: "same hardcoding 256s..."
Review comment: Add new `const bool use_hipblaslt_fp8 = (!use_mxfp8 || param.force_hipblaslt)` - this combination is used below for many skips. And all this should be below, under `ifdef HIP_PLATFORM_AMD`, under `has_fp8`.
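A rough sketch of what that suggestion could look like (my reading of the comment, not code from this PR); GTEST_SKIP is replaced by a returned reason string so the fragment stands on its own.

```cpp
// Sketch of the suggested consolidation, not code from this PR: compute one
// use_hipblaslt_fp8 flag and reuse it for the skips that currently repeat
// "(params.force_hipblaslt || !use_mxfp8)".
#include <string>

std::string mxfp8_skip_reason(bool use_mxfp8, bool force_hipblaslt, bool use_bias,
                              bool use_gelu, bool fp8_gelu_fusion_config) {
    const bool use_hipblaslt_fp8 = !use_mxfp8 || force_hipblaslt;
    if (use_mxfp8 && use_bias && use_hipblaslt_fp8)
        return "MXFP8 GEMM with bias is not supported by hipBLASLt";
    if (use_gelu && !fp8_gelu_fusion_config && use_hipblaslt_fp8)
        return "FP8 GEMM with GELU is not supported in current config";
    return "";  // no skip needed
}
```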