diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 5684be1cd..c5f8f4086 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -57,19 +57,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // FP8 special handling. // // A_use/B_use and transA_use/transB_use have already gone through the - // upstream-style grouped GEMM normalization above. This block only rewrites - // that normalized presentation into the CK FP8 preferred NT presentation by selecting - // `columnwise_data` when needed. + // upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is + // compiled only for the preferred NT presentation: // - // CK FP8 target presentation: - // A_use: N - // B_use: T + // transA_use = false + // transB_use = true // - // The outer condition checks whether this NT presentation is possible: - // - A_use is already N, or can be made N using columnwise_data - // - B_use is already T, or can be made T using columnwise_data + // This block rewrites the normalized presentation into that NT form by + // selecting columnwise_data when needed. If the required columnwise_data view + // is unavailable, this CK FP8 backend cannot represent the GEMM in its + // supported layout form, so we fall back instead of compiling/running an + // unsupported layout variant. 
// - // Then each operand is rewritten independently only if needed: + // Rewrite cases: // NN -> rewrite B only // TN -> rewrite A and B // NT -> already in target form @@ -81,16 +81,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const bool has_a_col = A0_te->has_columnwise_data(); const bool has_b_col = B0_te->has_columnwise_data(); - if ((!transA_use || has_a_col) && (transB_use || has_b_col)) { - if (transA_use) { - use_a_colwise_data = true; - transA_use = false; - } + const bool can_make_a_nt = !transA_use || has_a_col; + const bool can_make_b_nt = transB_use || has_b_col; - if (!transB_use) { - use_b_colwise_data = true; - transB_use = true; - } + if (!can_make_a_nt || !can_make_b_nt) { + NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. " + "Missing required columnwise_data for layout rewrite; falling back."); + return false; + } + + if (transA_use) { + use_a_colwise_data = true; + transA_use = false; + } + + if (!transB_use) { + use_b_colwise_data = true; + transB_use = true; } } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index c89f10232..552a5639d 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -7,6 +7,7 @@ #pragma once #include +#include "common/util/cuda_runtime.h" #include #include @@ -70,6 +71,28 @@ static inline const transformer_engine::SimpleTensor& scale_inv_view(const trans return t.scale_inv; } +enum class GPUArch { + GFX942, + GFX950, + GFX1250, + UNKNOWN +}; + +static inline GPUArch detect_gpu_arch() { + int arch = cuda::sm_arch(0); + + if (arch == 94) { + return GPUArch::GFX942; + } + if (arch == 95) { + return GPUArch::GFX950; + } + if (arch == 125 || arch == 1250) { + return GPUArch::GFX1250; + } + return GPUArch::UNKNOWN; +} + struct GroupedGemmRunContext { const NVTETensor* A = 
nullptr; const NVTETensor* B = nullptr; diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 660dbefb8..325a7ef53 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -14,7 +14,7 @@ namespace grouped_gemm { // Tile configs: FP16/BF16 // ------------------------- -struct TileCfg_256x256x64 { +struct TileCfg_256x256x64_MFMA { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 64; @@ -37,14 +37,37 @@ struct TileCfg_256x256x64 { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; -struct TileCfg_256x128x64 : TileCfg_256x256x64 { +struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA { static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { +struct TileCfg_256x128x64_MFMA_padding : TileCfg_256x128x64_MFMA { static constexpr bool kPadN = true; }; +struct TileCfg_256x256x64_WMMA { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + template (); \ }) -bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, 
+template +struct FP16TileCfg; + +template <> +struct FP16TileCfg { + using type = TileCfg_256x256x64_MFMA; +}; + +template <> +struct FP16TileCfg { + using type = TileCfg_256x256x64_MFMA; +}; + +template <> +struct FP16TileCfg { + using type = TileCfg_256x256x64_WMMA; +}; + +template +bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { @@ -229,13 +271,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; - - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); + + if constexpr (Arch == GPUArch::GFX1250) { + MAKE_RUNNER(TileCfg_256x256x64_WMMA); } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64_MFMA); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64_MFMA); + } else { + MAKE_RUNNER(TileCfg_256x128x64_MFMA_padding); + } } }); }); @@ -249,6 +295,23 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, return runner->run(s, ctx); } +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + switch (detect_gpu_arch()) { + case GPUArch::GFX942: + return ck_tile_grouped_gemm_fp16_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); + case GPUArch::GFX950: + return ck_tile_grouped_gemm_fp16_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); + case GPUArch::GFX1250: + return ck_tile_grouped_gemm_fp16_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); + default: + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); + return false; + } +} + #undef MAKE_RUNNER } // namespace grouped_gemm diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp 
index 50b701c05..8ccaf702e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -6,7 +6,6 @@ #include "ck_grouped_gemm_common.h" #include "ck_grouped_gemm_fp8.h" -#include "common/util/cuda_runtime.h" #include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" @@ -16,12 +15,6 @@ namespace transformer_engine { namespace grouped_gemm { -enum class GPUArch { - GFX942, - GFX950, - UNKNOWN -}; - struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -45,6 +38,29 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; +struct TileCfg_128x128x128_16x16x64_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; + // gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile // configuration due to an unsupported warp GEMM dispatcher configuration. // See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. 
@@ -115,8 +131,7 @@ class QuantGroupedGemmRunner : public RunnerInterface { AccType, GemmShape, UniversalTraits, - false, - AccType>; + false>; using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -265,18 +280,6 @@ class QuantGroupedGemmRunner : public RunnerInterface { } }; -static inline GPUArch detect_gpu_arch() { - int arch = cuda::sm_arch(0); - - if (arch == 94) { - return GPUArch::GFX942; - } - if (arch == 95) { - return GPUArch::GFX950; - } - return GPUArch::UNKNOWN; -} - template struct FP8TileCfg; @@ -290,6 +293,11 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; +template <> +struct FP8TileCfg { + using type = TileCfg_128x128x128_16x16x64_2x2x1; +}; + template static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, DType b_dtype, @@ -301,31 +309,38 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, using CTypeLayout = RowMajor; using TileCfg = typename FP8TileCfg::type; - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { - using BLayout = std::conditional_t; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { - using AType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { - using BType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - }); - }); + // FP8 grouped GEMM is only compiled for CK's preferred NT presentation: + // transA=false, transB=true + // which maps to: + // ALayout=RowMajor, BLayout=ColMajor. + // + // The caller is responsible for rewriting other FP8 layouts into this form + // using columnwise_data when needed. Reject anything that did not normalize + // successfully so we do not instantiate unreachable/unsupported layout variants. 
+  if (ctx.transA || !ctx.transB) { +    return false; +  } + +  using ALayout = RowMajor; +  using BLayout = ColMajor; + +  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { +    using AType = typename TETypeToCKType::type; + +    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { +      using BType = typename TETypeToCKType::type; + +      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { +        using CType = typename TETypeToCKType::type; +        using Runner = QuantGroupedGemmRunner; +        runner = std::make_unique(); });       });     }); @@ -346,8 +361,10 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); case GPUArch::GFX950: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +    case GPUArch::GFX1250: +      return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); default: -      NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); +      NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); return false; } } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..efdff259d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,6 +1123,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor #else // Currently only support cutlass group gemm on Hopper Arch -  if (!(is_hopper && use_cutlass)) { +  // NOTE(review): an earlier revision of this patch nested an extra "if (!use_cutlass) {" inside the old condition without a matching "}", which cannot compile; the nested pair logically reduces to !use_cutlass, so that condition is used directly here. Confirm that cutlass grouped GEMM is now intended beyond Hopper. +  if (!use_cutlass) { #endif cublas_path(); return;