Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/rocm-wheels-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ jobs:
3rdparty/aotriton \
3rdparty/aiter \
3rdparty/QoLA \
3rdparty/hipify_torch
3rdparty/hipify_torch \
3rdparty/hipkittens

- name: Derive Docker image tag
id: set-tag
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@
[submodule "3rdparty/QoLA"]
path = 3rdparty/QoLA
url = https://github.com/Micky774/QoLA.git
[submodule "3rdparty/hipkittens"]
path = 3rdparty/hipkittens
url = https://github.com/HazyResearch/HipKittens.git
1 change: 1 addition & 0 deletions 3rdparty/hipkittens
Submodule hipkittens added at 778274
1 change: 1 addition & 0 deletions ci/pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ run_test_config(){
run 1 test_jit.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa 1 test_multi_tensor.py
run 1 test_numerics.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa_lbl "mxfp8" 1 test_numerics.py -k "recipe0 and 126m and not grouped"
run_default_fa 1 test_permutation.py
run_default_fa 1 test_recipe.py
run 1 test_sanity.py
Expand Down
134 changes: 90 additions & 44 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{32, 128, 16},
{256, 256, 256},
{768, 3072, 4096},
{4096, 16384, 4096},
};

// A, B, Bias, Gelu, D
Expand Down Expand Up @@ -168,6 +170,20 @@ __global__ void compute_ref_kernel(
}


// MXFP8 geometry: one scale per 32-element group, K processed in 128-wide tiles.
constexpr size_t kMXFP8GroupSize = 32;
constexpr size_t kKTileSize = 128;

// Returns the GEMM workspace size in bytes for an m x k x n MXFP8 problem,
// never smaller than base_size (the hipBLASLt minimum).
//
// Accounting (presumably matching the HipKittens kernel's staging layout —
// confirm against the kernel source):
//   - packed per-K-tile scale blocks for A and B (4 bytes per tile-row entry);
//   - when an operand arrives in the unsupported layout (A non-transposed or
//     B transposed), additional room to re-stage the operand data, its
//     row-wise scales, and a second packed-scale buffer.
// Sub-buffers are padded to 256-byte boundaries via round_up_to_nearest_multiple.
static size_t compute_mxfp8_workspace_size(size_t m, size_t k, size_t n, bool transa, bool transb, size_t base_size) {
  const size_t k_iters = k / kKTileSize;       // number of 128-wide K tiles
  const size_t scale_k = k / kMXFP8GroupSize;  // scale entries per row of K
  // sa_pk is rounded here once; it is already 256-aligned everywhere below.
  const size_t sa_pk = round_up_to_nearest_multiple(k_iters * m * 4, 256);
  const size_t sb_pk = k_iters * n * 4;
  size_t needed = sa_pk + sb_pk;
  if (!transa) {
    // A needs re-staging: operand data + row-wise scales + packed-scale copy.
    needed += round_up_to_nearest_multiple(m * k, 256) +
              round_up_to_nearest_multiple(m * scale_k, 256) +
              sa_pk;
  }
  if (transb) {
    // B needs re-staging: operand data + row-wise scales + packed-scale copy.
    needed += round_up_to_nearest_multiple(n * k, 256) +
              round_up_to_nearest_multiple(n * scale_k, 256) +
              round_up_to_nearest_multiple(sb_pk, 256);
  }
  return std::max(base_size, needed);
}

struct TestParams {
size_t m;
size_t k;
Expand All @@ -177,6 +193,7 @@ struct TestParams {
bool transa;
bool transb;
NVTEScalingMode scaling_mode;
bool force_hipblaslt;
};


Expand Down Expand Up @@ -341,8 +358,7 @@ void performTest(const TestParams& params) {
const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype);
const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING;

if (use_mxfp8)
{
if (use_mxfp8) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add new const bool use_hipblaslt_fp8 = (!use_mxfp8 || param.force_hipblaslt) - this combination is used below for many skips. And all this should be below, under ifdef HIP_PLATFORM_AMD under has_fp8

if (!has_fp8) {
GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types";
}
Expand All @@ -352,6 +368,9 @@ void performTest(const TestParams& params) {
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
}
if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k < 256)) {
GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256";
}
Comment thread
wangye805 marked this conversation as resolved.
}

cudaDeviceProp prop;
Expand Down Expand Up @@ -387,22 +406,14 @@ void performTest(const TestParams& params) {
if (!fp8_supported) {
GTEST_SKIP() << "FP8 is not supported in current config";
}

if (use_mxfp8)
{
bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (isFp8Type(dtype)){
GTEST_SKIP() << "MXFP8 with float8 output is not supported";
}
if (params.use_bias) {
GTEST_SKIP() << "MXFP8 GEMM with bias is not supported";
}
bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
if (use_mxfp8 && !mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}

if (params.use_gelu && !fp8_gelu_fusion_config) {
if (use_mxfp8 && params.use_bias && params.force_hipblaslt) {
GTEST_SKIP() << "MXFP8 GEMM with bias is not supported by hipBLASLt";
}
if (params.use_gelu && !fp8_gelu_fusion_config && (params.force_hipblaslt || !use_mxfp8)) {
GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config";
}
if (params.use_bias && dtype == DType::kFloat16) {
Expand All @@ -412,29 +423,27 @@ void performTest(const TestParams& params) {

if (prop.major == 9 && prop.minor == 5) //gfx950 specific hipblasLt limitations
{
if (isFp8Type(dtype)){
if (isFp8Type(dtype)) {
GTEST_SKIP() << "GEMM with float8 output is not supported";
}
if (params.use_gelu && dtype == DType::kBFloat16) {
if (params.use_gelu && dtype == DType::kBFloat16 && (params.force_hipblaslt || !use_mxfp8)) {
GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
}
if constexpr ((std::is_same<A_Type, bf8>::value || std::is_same<B_Type, bf8>::value) &&
std::is_same<D_Type, fp32>::value)
{
//GEMM with bias and fp32 output is not supported with bf8 A/B
if constexpr ((std::is_same_v<A_Type, bf8> || std::is_same_v<B_Type, bf8>) &&
std::is_same_v<D_Type, fp32>) {
if (params.use_bias) {
GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config";
}
}
}
if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
else if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
{
#if HIP_VERSION < 70100000
if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
}
#endif
if constexpr (std::is_same<D_Type, fp8>::value && std::is_same<Bias_Type, bf16>::value) {
if constexpr (std::is_same_v<D_Type, fp8> && std::is_same_v<Bias_Type, bf16>) {
if (params.use_bias && !fp8_gelu_fusion_config) {
GTEST_SKIP() << "GEMM with BF16 bias and FP8 output is not supported in current config";
}
Expand Down Expand Up @@ -493,6 +502,11 @@ void performTest(const TestParams& params) {
if ((prop.major == 9 && prop.minor == 5) || prop.major >= 12) {
workspace_size = 67108864;
}
if (use_mxfp8 && !params.force_hipblaslt) {
workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
Comment thread
wangye805 marked this conversation as resolved.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skip it if force_hipblaslt?

params.transa, params.transb,
workspace_size);
}
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

Expand Down Expand Up @@ -547,11 +561,12 @@ void performTest(const TestParams& params) {
}

auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8);
size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
RefD.to_cpu();
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);

if(params.use_gelu){
auto [atol, rtol] = getTestTolerances(gelu_type, false, false);
auto [atol, rtol] = getTestTolerances(gelu_type, has_fp8, use_mxfp8);
RefPreGeluOut.to_cpu();
compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr<Gelu_Type>(), true, atol, rtol);
}
Expand Down Expand Up @@ -581,6 +596,12 @@ void performDqTest(const TestParams &params) {
if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias || params.use_gelu) {
GTEST_SKIP() << "DqGEMMTestSuite does not yet have reference for bias/gelu epilogues";
}
if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k % 128 || params.k < 256)) {
GTEST_SKIP() << "HipKittens requires M and N 256-aligned, K >= 256";
}

DType ref_type = dtype;
TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m};
Expand Down Expand Up @@ -608,7 +629,9 @@ void performDqTest(const TestParams &params) {
Tensor bias;
Tensor pre_gelu_out;

size_t workspace_size = 67108864;
size_t workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
params.transa, params.transb,
67108864); // 64 MiB required for hipBLASlt
Tensor Workspace("Workspace", TShape{workspace_size}, DType::kByte);

//perform FP8 gemm and copy the output results from GPU memory to CPU memory
Expand Down Expand Up @@ -638,6 +661,12 @@ void performDqTest(const TestParams &params) {
#endif // __HIP_PLATFORM_AMD__

#define MAKE_TEST_PARAMS(P_) \
bool force_hipblaslt_ = std::get<5>(GetParam()); \
if (force_hipblaslt_) { \
setenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8", "1", 1); \
} else { \
unsetenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8"); \
} \
TestParams P_ = {.m = std::get<0>(std::get<0>(GetParam())), \
.k = std::get<1>(std::get<0>(GetParam())), \
.n = std::get<2>(std::get<0>(GetParam())), \
Expand All @@ -646,13 +675,14 @@ void performDqTest(const TestParams &params) {
.transa = std::get<3>(GetParam()).first, \
.transb = std::get<3>(GetParam()).second, \
.scaling_mode = std::get<4>(GetParam()) \
? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
: NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING}
? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
: NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING,\
.force_hipblaslt = force_hipblaslt_}

// <m, k, n>, use_bias, use_gelu, Layout, fp8_scalinig
// <m, k, n>, use_bias, use_gelu, Layout, fp8_scaling, force_hipblaslt
class GEMMTestSuite
: public ::testing::TestWithParam<
std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode>> {};
std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode, bool>> {};

#define MAKE_GEMM_TEST(NAME_, A_, B_, BIAS_, GELU_, D_) \
TEST_P(GEMMTestSuite, NAME_) { \
Expand Down Expand Up @@ -713,19 +743,32 @@ static inline auto MKN(const std::tuple<size_t, size_t, size_t>& shape) {
std::to_string(std::get<2>(shape));
}

// Builds the gtest parameter name "<MxKxN>x<bias>x<gelu>x<layout>x<scaling>x<backend>":
// bias/gelu are rendered as 0/1, scaling as "M" (MXFP8) or "S" (tensor scaling),
// and the backend tag as "HB" (hipBLASLt forced) or "HK" (HipKittens eligible).
static std::string GEMMTestName(const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
  const auto& p = info.param;
  std::string name = MKN(std::get<0>(p));            // problem shape MxKxN
  name += "x" + std::to_string(std::get<1>(p));      // use_bias flag
  name += "x" + std::to_string(std::get<2>(p));      // use_gelu flag
  name += "x";
  name += TN(std::get<3>(p));                        // transa/transb layout tag
  name += std::get<4>(p) ? "xM" : "xS";              // scaling mode
  name += std::get<5>(p) ? "xHB" : "xHK";            // forced backend
  return name;
}

INSTANTIATE_TEST_SUITE_P(OperatorTest, GEMMTestSuite,
::testing::Combine(::testing::ValuesIn(test_case_sizes),
::testing::Values(false, true), //use bias
::testing::Values(false, true), //use_gelu
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(false, true)), //use mxfp8
[](const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +
std::to_string(std::get<2>(info.param)) + "x" +
TN(std::get<3>(info.param)) + "x" +
(std::get<4>(info.param) ? "M" : "S");
});
::testing::Values(false), //use mxfp8
::testing::Values(false)), //force hipblaslt
GEMMTestName);

INSTANTIATE_TEST_SUITE_P(OperatorTestMXFP8, GEMMTestSuite,
::testing::Combine(::testing::ValuesIn(test_case_sizes),
::testing::Values(false, true), //use bias
::testing::Values(false, true), //use_gelu
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(true), //use mxfp8
::testing::Values(false, true)), //force hipblaslt
GEMMTestName);

#ifdef __HIP_PLATFORM_AMD__
class DqGEMMTestSuite: public GEMMTestSuite {};
Expand All @@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
::testing::Combine(::testing::ValuesIn(test_case_sizes_mxfp8),
::testing::Values(false), // bias - unused
::testing::Values(false), // gelu - unused
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(true)), //use mxfp8
::testing::Values(false), // use bias
::testing::Values(false), // use gelu
::testing::ValuesIn(kLayouts), // transa,transb
::testing::Values(true), // use mxfp8
::testing::Values(false, true)), // force hipblaslt
[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
return MKN(std::get<0>(info.param)) + "x" +
TN(std::get<3>(info.param)) + "x" +
(std::get<5>(info.param) ? "HB" : "HK");
});

TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
]
TEST_SHAPES = [(64, 32, 64)]
if is_hip_extension():
TEST_SHAPES += [(64, 64, 128), (128, 256, 256)]
TEST_SHAPES += [(64, 64, 128), (128, 256, 256), (256, 256, 256)]
jnp_float8_e4m3_type = get_jnp_float8_e4m3_type()
jnp_float8_e5m2_type = get_jnp_float8_e5m2_type()

Expand Down
5 changes: 3 additions & 2 deletions tests/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False):
pytest.skip(
f"Input shape {(m, k)} x {(k, n)} is not supported by hipblaslt MXFP8 GEMM."
)
if use_bias:
pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same hardcoding 256s...

if use_bias and not hipkittens_eligible:
pytest.skip("hipblaslt GEMM does not support MXFP8 with bias.")
else:
jax_version = version.parse(jax.__version__)
if jax_version < version.parse("0.8.2"):
Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ cmake_minimum_required(VERSION 3.21)
option(USE_ROCM "Use ROCm" ON)
option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
option(USE_FUSED_ATTN_CK "Use ck backend" ON)
option(USE_HIPKITTENS_GEMM "Use HipKittens MXFP8 GEMM kernels" ON)
set(USE_CUDA OFF)

if (USE_ROCM)
Expand Down Expand Up @@ -453,6 +454,10 @@ else()
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
endif()

if(USE_HIPKITTENS_GEMM)
add_subdirectory(gemm/kittens ${CMAKE_CURRENT_BINARY_DIR}/kittens)
endif()

find_package(hip)
list(APPEND transformer_engine_LINKER_LIBS hip::host hip::device roctx64)
find_package(hiprtc)
Expand All @@ -467,6 +472,10 @@ else()
target_compile_definitions(transformer_engine PUBLIC USE_FUSED_ATTN_CK)
list(APPEND transformer_engine_LINKER_LIBS ck_fused_attn)
endif()
if(USE_HIPKITTENS_GEMM)
target_compile_definitions(transformer_engine PUBLIC USE_HIPKITTENS_GEMM)
list(APPEND transformer_engine_LINKER_LIBS kittens_gemm)
endif()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()

Expand Down
Loading
Loading