Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
// FP8 special handling.
//
// A_use/B_use and transA_use/transB_use have already gone through the
// upstream-style grouped GEMM normalization above. This block only rewrites
// that normalized presentation into the CK FP8 preferred NT presentation by selecting
// `columnwise_data` when needed.
// upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is
// compiled only for the preferred NT presentation:
//
// CK FP8 target presentation:
// A_use: N
// B_use: T
// transA_use = false
// transB_use = true
//
// The outer condition checks whether this NT presentation is possible:
// - A_use is already N, or can be made N using columnwise_data
// - B_use is already T, or can be made T using columnwise_data
// This block rewrites the normalized presentation into that NT form by
// selecting columnwise_data when needed. If the required columnwise_data view
// is unavailable, this CK FP8 backend cannot represent the GEMM in its
// supported layout form, so we fall back instead of compiling/running an
// unsupported layout variant.
//
// Then each operand is rewritten independently only if needed:
// Rewrite cases:
// NN -> rewrite B only
// TN -> rewrite A and B
// NT -> already in target form
Expand All @@ -81,16 +81,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
const bool has_a_col = A0_te->has_columnwise_data();
const bool has_b_col = B0_te->has_columnwise_data();

if ((!transA_use || has_a_col) && (transB_use || has_b_col)) {
if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}
const bool can_make_a_nt = !transA_use || has_a_col;
const bool can_make_b_nt = transB_use || has_b_col;

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
if (!can_make_a_nt || !can_make_b_nt) {
NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. "
"Missing required columnwise_data for layout rewrite; falling back.");
return false;
}

if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once

#include <hip/hip_runtime.h>
#include "common/util/cuda_runtime.h"

#include <array>
#include <type_traits>
Expand Down Expand Up @@ -70,6 +71,28 @@ static inline const transformer_engine::SimpleTensor& scale_inv_view(const trans
return t.scale_inv;
}

// AMD GPU architectures recognized by the CK-tile grouped GEMM backends.
// Produced at runtime by detect_gpu_arch() and then used as a compile-time
// template argument so per-arch tile/kernel configurations can be pruned
// with `if constexpr` (e.g. MFMA configs for gfx942/gfx950, WMMA for
// gfx1250 — see the FP16TileCfg/FP8TileCfg specializations).
enum class GPUArch {
  GFX942,
  GFX950,
  GFX1250,
  UNKNOWN  // unrecognized device; dispatchers report an error for this case
};

// Map the architecture of device 0, as reported by cuda::sm_arch(), to a
// GPUArch value: 94 -> gfx942, 95 -> gfx950. For gfx1250 both 125 and 1250
// are accepted.
// NOTE(review): confirm which integer cuda::sm_arch() actually returns for
// gfx1250 and drop the redundant case.
static inline GPUArch detect_gpu_arch() {
  const int device_arch = cuda::sm_arch(0);

  switch (device_arch) {
    case 94:
      return GPUArch::GFX942;
    case 95:
      return GPUArch::GFX950;
    case 125:
    case 1250:
      return GPUArch::GFX1250;
    default:
      return GPUArch::UNKNOWN;
  }
}

struct GroupedGemmRunContext {
const NVTETensor* A = nullptr;
const NVTETensor* B = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace grouped_gemm {
// Tile configs: FP16/BF16
// -------------------------

struct TileCfg_256x256x64 {
struct TileCfg_256x256x64_MFMA {
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
Expand All @@ -37,14 +37,37 @@ struct TileCfg_256x256x64 {
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

struct TileCfg_256x128x64 : TileCfg_256x256x64 {
// 256x128x64 MFMA tile: identical to TileCfg_256x256x64_MFMA except the
// N tile is halved to 128. Selected by the FP16 dispatch when N is a
// multiple of 128 but not of 256.
struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA {
  static constexpr ck_tile::index_t N_Tile = 128;
};

struct TileCfg_256x128x64_padding : TileCfg_256x128x64 {
// N-padded variant of the 256x128x64 MFMA tile: the fallback used by the
// FP16 dispatch when N is not a multiple of 128, so the kernel must pad
// the N dimension.
struct TileCfg_256x128x64_MFMA_padding : TileCfg_256x128x64_MFMA {
  static constexpr bool kPadN = true;
};

// 256x256x64 tile configuration for WMMA-based architectures (selected for
// gfx1250 via FP16TileCfg). Field set mirrors the MFMA configs consumed by
// the CK-tile grouped GEMM kernel templates.
struct TileCfg_256x256x64_WMMA {
  // Per-workgroup tile extents (M x N x K).
  static constexpr ck_tile::index_t M_Tile = 256;
  static constexpr ck_tile::index_t N_Tile = 256;
  static constexpr ck_tile::index_t K_Tile = 64;

  // Warp grid within the workgroup tile.
  static constexpr ck_tile::index_t M_Warp = 2;
  static constexpr ck_tile::index_t N_Warp = 2;
  static constexpr ck_tile::index_t K_Warp = 1;

  // Per-warp tile extents (16x16x32) — presumably matching a WMMA
  // instruction shape supported on gfx1250; TODO confirm against CK's
  // warp-gemm dispatcher variants.
  static constexpr ck_tile::index_t M_Warp_Tile = 16;
  static constexpr ck_tile::index_t N_Warp_Tile = 16;
  static constexpr ck_tile::index_t K_Warp_Tile = 32;

  // All dimensions padded: unlike the MFMA path, the gfx1250 FP16 dispatch
  // uses this config for every problem size, with no divisibility check.
  static constexpr bool kPadM = true;
  static constexpr bool kPadN = true;
  static constexpr bool kPadK = true;

  static constexpr bool DoubleSmemBuffer = false;

  // Tile-partitioner tuning parameters.
  static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
  static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

template <typename AType,
typename BType,
typename CType,
Expand Down Expand Up @@ -209,7 +232,26 @@ class GroupedGemmRunner : public RunnerInterface {
runner = std::make_unique<Runner>(); \
})

bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
// Compile-time map from GPU architecture to a default FP16/BF16 tile
// configuration. Keeping the arch a template parameter lets callers prune
// arch-incompatible kernel instantiations with `if constexpr` (the MFMA
// configs are never instantiated for gfx1250).
// NOTE(review): the visible FP16 dispatch picks tile configs directly by
// name rather than through this trait — confirm this trait is still
// referenced, or wire it into the dispatch.
template <GPUArch Arch>
struct FP16TileCfg;

// gfx942: MFMA tile config.
template <>
struct FP16TileCfg<GPUArch::GFX942> {
  using type = TileCfg_256x256x64_MFMA;
};

// gfx950: same MFMA tile config as gfx942.
template <>
struct FP16TileCfg<GPUArch::GFX950> {
  using type = TileCfg_256x256x64_MFMA;
};

// gfx1250: WMMA tile config.
template <>
struct FP16TileCfg<GPUArch::GFX1250> {
  using type = TileCfg_256x256x64_WMMA;
};

template <GPUArch Arch>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need a template instead of a regular if-else or switch-case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The template is needed because the arch selection affects CK kernel template instantiation, not just runtime control flow. GPUArch must be a compile-time value so if constexpr can prune unsupported tile/kernel combinations for a given architecture. In this case, it prevents the MFMA configs from being instantiated for gfx1250.

bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype,
DType b_dtype,
DType d_dtype,
const GroupedGemmRunContext& ctx) {
Expand All @@ -229,13 +271,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);

if constexpr (Arch == GPUArch::GFX1250) {
MAKE_RUNNER(TileCfg_256x256x64_WMMA);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64_MFMA);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64_MFMA);
} else {
MAKE_RUNNER(TileCfg_256x128x64_MFMA_padding);
}
}
});
});
Expand All @@ -249,6 +295,23 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
return runner->run(s, ctx);
}

// FP16/BF16 grouped GEMM entry point: detect the device architecture at
// runtime, then forward to the arch-templated dispatcher so that kernel
// instantiation remains a compile-time, per-architecture decision.
// Returns the dispatcher's result; raises NVTE_ERROR on an unsupported
// architecture.
bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
                                        DType b_dtype,
                                        DType d_dtype,
                                        const GroupedGemmRunContext& ctx) {
  const GPUArch arch = detect_gpu_arch();

  if (arch == GPUArch::GFX942) {
    return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX942>(a_dtype, b_dtype, d_dtype, ctx);
  }
  if (arch == GPUArch::GFX950) {
    return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX950>(a_dtype, b_dtype, d_dtype, ctx);
  }
  if (arch == GPUArch::GFX1250) {
    return ck_tile_grouped_gemm_fp16_dispatch_arch<GPUArch::GFX1250>(a_dtype, b_dtype, d_dtype, ctx);
  }

  // Unsupported device: NVTE_ERROR reports the supported set; the return
  // below exists only for control-flow completeness.
  NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}");
  return false;
}

#undef MAKE_RUNNER

} // namespace grouped_gemm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "ck_grouped_gemm_common.h"
#include "ck_grouped_gemm_fp8.h"
#include "common/util/cuda_runtime.h"

#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
Expand All @@ -16,12 +15,6 @@
namespace transformer_engine {
namespace grouped_gemm {

enum class GPUArch {
GFX942,
GFX950,
UNKNOWN
};

struct TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
Expand All @@ -45,6 +38,29 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t TilePartitionerM01 = 8;
};

// FP8 tile configuration selected for gfx1250 (via FP8TileCfg): 128x128x128
// workgroup tile, 16x16x64 warp tile, 2x2x1 warp grid. Field set mirrors
// the 16x16x128 config used on gfx942/gfx950.
struct TileCfg_128x128x128_16x16x64_2x2x1 {
  // Per-workgroup tile extents (M x N x K).
  static constexpr ck_tile::index_t M_Tile = 128;
  static constexpr ck_tile::index_t N_Tile = 128;
  static constexpr ck_tile::index_t K_Tile = 128;

  // Warp grid within the workgroup tile.
  static constexpr ck_tile::index_t M_Warp = 2;
  static constexpr ck_tile::index_t N_Warp = 2;
  static constexpr ck_tile::index_t K_Warp = 1;

  // Per-warp tile extents; K_Warp_Tile = 64 here vs. 128 on the
  // gfx942/gfx950 config — presumably to match gfx1250's supported
  // warp-gemm shapes (TODO confirm against CK's warp_gemm_dispatcher).
  static constexpr ck_tile::index_t M_Warp_Tile = 16;
  static constexpr ck_tile::index_t N_Warp_Tile = 16;
  static constexpr ck_tile::index_t K_Warp_Tile = 64;

  // No padding: FP8 grouped GEMM only runs in the normalized NT layout,
  // and this config assumes exact tile divisibility.
  // NOTE(review): confirm callers guarantee divisibility, since kPad* are
  // all false here.
  static constexpr bool kPadM = false;
  static constexpr bool kPadN = false;
  static constexpr bool kPadK = false;

  static constexpr bool DoubleSmemBuffer = false;

  // Tile-partitioner tuning parameters.
  static constexpr ck_tile::index_t TilePartitionerGroupNum = 16;
  static constexpr ck_tile::index_t TilePartitionerM01 = 8;
};

// gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile
// configuration due to an unsupported warp GEMM dispatcher configuration.
// See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants.
Expand Down Expand Up @@ -115,8 +131,7 @@ class QuantGroupedGemmRunner : public RunnerInterface {
AccType,
GemmShape,
UniversalTraits,
false,
AccType>;
false>;

using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;

Expand Down Expand Up @@ -265,18 +280,6 @@ class QuantGroupedGemmRunner : public RunnerInterface {
}
};

static inline GPUArch detect_gpu_arch() {
int arch = cuda::sm_arch(0);

if (arch == 94) {
return GPUArch::GFX942;
}
if (arch == 95) {
return GPUArch::GFX950;
}
return GPUArch::UNKNOWN;
}

template <GPUArch Arch>
struct FP8TileCfg;

Expand All @@ -290,6 +293,11 @@ struct FP8TileCfg<GPUArch::GFX950> {
using type = TileCfg_128x128x128_16x16x128_2x2x1;
};

// gfx1250 FP8: select the 16x16x64-warp-tile configuration (gfx942/gfx950
// use the 16x16x128 variant).
template <>
struct FP8TileCfg<GPUArch::GFX1250> {
  using type = TileCfg_128x128x128_16x16x64_2x2x1;
};

template <GPUArch Arch>
static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
DType b_dtype,
Expand All @@ -301,31 +309,38 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype,
using CTypeLayout = RowMajor;
using TileCfg = typename FP8TileCfg<Arch>::type;

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, {
using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, {
using BLayout = std::conditional_t<kTransB, ColMajor, RowMajor>;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, {
using AType = typename TETypeToCKType<a_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, {
using BType = typename TETypeToCKType<b_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;
using Runner = QuantGroupedGemmRunner<AType,
BType,
CType,
ALayout,
BLayout,
CTypeLayout,
TileCfg,
ck_tile::memory_operation_enum::set>;
runner = std::make_unique<Runner>();
});
});
// FP8 grouped GEMM is only compiled for CK's preferred NT presentation:
// transA=false, transB=true
// which maps to:
// ALayout=RowMajor, BLayout=ColMajor.
//
// The caller is responsible for rewriting other FP8 layouts into this form
// using columnwise_data when needed. Reject anything that did not normalize
// successfully so we do not instantiate unreachable/unsupported layout variants.
if (ctx.transA || !ctx.transB) {
return false;
}

using ALayout = RowMajor;
using BLayout = ColMajor;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, {
using AType = typename TETypeToCKType<a_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, {
using BType = typename TETypeToCKType<b_te_type>::type;

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;
using Runner = QuantGroupedGemmRunner<AType,
BType,
CType,
ALayout,
BLayout,
CTypeLayout,
TileCfg,
ck_tile::memory_operation_enum::set>;
runner = std::make_unique<Runner>();
});
});
});
Expand All @@ -346,8 +361,10 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype,
return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX942>(a_dtype, b_dtype, d_dtype, ctx);
case GPUArch::GFX950:
return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX950>(a_dtype, b_dtype, d_dtype, ctx);
case GPUArch::GFX1250:
return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX1250>(a_dtype, b_dtype, d_dtype, ctx);
default:
NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}");
NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}");
return false;
}
}
Expand Down
Loading
Loading