diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0ebd7fdfe..20d6919cc 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -31,11 +31,11 @@ list(APPEND test_cuda_sources test_multi_unpadding.cu test_causal_softmax.cu test_swap_first_dims.cu + test_swizzle.cu ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_float8blockwise.cu - test_swizzle.cu) + test_cast_float8blockwise.cu) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 3209d2335..ec0877776 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -166,3 +166,428 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<2>(info.param)); return name; }); + +#ifdef __HIP_PLATFORM_AMD__ + +// ============================================================================ +// End-to-end MXFP8 GEMM test with pre-swizzled scales +// +// Verifies that the full pipeline works: +// 1. Create MXFP8 FP8 tensors with random data + scales +// 2. Run a reference GEMM (using un-swizzled scales) +// 3. Swizzle the scales via nvte_swizzle_scaling_factors +// 4. Run the actual hipBLASlt GEMM +// 5. Compare results +// ============================================================================ + +#include + +// Helper: swizzle the MXFP8 scale_inv of a test::Tensor in-place. +// Allocates a temp device buffer, swizzles into it, copies back. +static void swizzle_tensor_scales(test::Tensor &t, bool rowwise) { + using namespace transformer_engine; + + void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() + : t.columnwise_scale_inv_dptr(); + if (!scale_ptr) return; + + const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() + : t.columnwise_scale_inv_shape(); + const NVTEShape data_shape = rowwise ? t.rowwise_shape() + : t.columnwise_shape(); + + size_t num_scales = 1; + for (size_t d = 0; d < scale_shape.ndim; d++) { + num_scales *= scale_shape.data[d]; + } + + // Allocate temp buffer for swizzled output + uint8_t *d_tmp = nullptr; + ASSERT_EQ(cudaMalloc(&d_tmp, num_scales), cudaSuccess); + + // Build TensorWrapper pair for the swizzle call + TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); + TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); + output_tw.set_with_gemm_swizzled_scales(true); + + if (rowwise) { + input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } else { + input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } + + nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); + + // Copy swizzled scales back over the original + ASSERT_EQ(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice), cudaSuccess); + cudaFree(d_tmp); + + // Mark tensor as having swizzled scales + t.set_with_gemm_swizzled_scales(true); +} + +// Simple GPU reference kernel for MXFP8 GEMM: D = A * B^T (TN layout) +// A is [M, K] row-major, B is [N, K] row-major, D is [M, N] column-major +// Scales are E8M0, one per group of 32 elements along K. +__global__ void mxfp8_gemm_ref_kernel( + const test::fp8e4m3 *a_data, const uint8_t *a_scale, size_t a_scale_ld, + const test::fp8e4m3 *b_data, const uint8_t *b_scale, size_t b_scale_ld, + test::bf16 *d_data, + size_t M, size_t K, size_t N) { + const size_t i = blockIdx.y * blockDim.y + threadIdx.y; + const size_t j = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= M || j >= N) return; + + float acc = 0.0f; + for (size_t kk = 0; kk < K; kk++) { + size_t kc = kk / 32; + float a_sinv = exp2f(static_cast(a_scale[i * a_scale_ld + kc]) - 127.0f); + float b_sinv = exp2f(static_cast(b_scale[j * b_scale_ld + kc]) - 127.0f); + float a_val = static_cast(a_data[i * K + kk]); + float b_val = static_cast(b_data[j * K + kk]); + acc += a_sinv * a_val * b_sinv * b_val; + } + d_data[i + j * M] = static_cast(acc); +} + +struct MxGemmParams { + size_t m, k, n; +}; + +class MxGemmTestSuite + : public ::testing::TestWithParam {}; + +TEST_P(MxGemmTestSuite, TestMxfp8GemmE2E) { + using namespace transformer_engine; + using namespace test; + + const auto &p = GetParam(); + const size_t M = p.m; + const size_t K = p.k; + const size_t N = p.n; + + cudaDeviceProp prop; + ASSERT_EQ(cudaGetDeviceProperties(&prop, 0), cudaSuccess); + + // MXFP8 requires gfx950+ + bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || + (prop.major >= 10); + if (!mxfp8_supported) { + GTEST_SKIP() << "MXFP8 GEMM not supported on this GPU"; + } + + // TN layout: A is [M, K], B is [N, K] + const bool transa = true; + const bool transb = false; + + Tensor A("A", std::vector{M, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING); + Tensor B("B", std::vector{N, K}, DType::kFloat8E4M3, true, false, NVTE_MXFP8_1D_SCALING); + Tensor D("D", std::vector{N, M}, DType::kBFloat16); + Tensor RefD("RefD", std::vector{N, M}, DType::kBFloat16); + Tensor bias; + Tensor pre_gelu_out; + + fillUniform(&A); + fillUniform(&B); + + // Override scales with values in [120,127] so layout errors are detectable. + // Default random [0,127] produces mostly tiny scales (2^(-127)..2^0), + // making the test insensitive to permutation errors. + { + auto fill_discriminating_scales = [](void *scale_ptr, size_t count) { + std::vector h(count); + std::mt19937 rng(42); + std::uniform_int_distribution dist(120, 127); + for (size_t i = 0; i < count; i++) h[i] = dist(rng); + cudaMemcpy(scale_ptr, h.data(), count, cudaMemcpyHostToDevice); + }; + auto a_sh = A.rowwise_scale_inv_shape(); + auto b_sh = B.rowwise_scale_inv_shape(); + fill_discriminating_scales(A.rowwise_scale_inv_dptr(), a_sh.data[0] * a_sh.data[1]); + fill_discriminating_scales(B.rowwise_scale_inv_dptr(), b_sh.data[0] * b_sh.data[1]); + } + + // GPU reference with un-swizzled (compact) scales + const auto a_scale_shape = A.rowwise_scale_inv_shape(); + const auto b_scale_shape = B.rowwise_scale_inv_shape(); + + std::cout << " A_scale shape: [" << a_scale_shape.data[0] << ", " << a_scale_shape.data[1] + << "], B_scale shape: [" << b_scale_shape.data[0] << ", " << b_scale_shape.data[1] + << "]" << std::endl; + + { + dim3 block(16, 16); + dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y); + mxfp8_gemm_ref_kernel<<>>( + static_cast(A.rowwise_dptr()), + static_cast(A.rowwise_scale_inv_dptr()), + a_scale_shape.data[1], + static_cast(B.rowwise_dptr()), + static_cast(B.rowwise_scale_inv_dptr()), + b_scale_shape.data[1], + static_cast(RefD.rowwise_dptr()), + M, K, N); + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); + } + + // Swizzle scales to K-tiled layout for hipBLASlt BLK32_UE8M0_32_8_EXT on gfx1250. + // Layout: {M, K_scale}.reshape({M, K_scale/4, 4}).permute({1,0,2}) + // dst(m,k) = (k/4)*M*4 + m*4 + (k%4) + if (prop.major >= 12) { + swizzle_tensor_scales(A, true); + swizzle_tensor_scales(B, true); + } + + // Run actual GEMM + size_t workspace_size = 134217728; // 128MB + Tensor Workspace("Workspace", std::vector{workspace_size}, DType::kByte); + + nvte_cublas_gemm(A.data(), B.data(), D.data(), + bias.data(), pre_gelu_out.data(), + transa, transb, + /*grad=*/false, + Workspace.data(), + /*accumulate=*/false, + /*use_split_accumulator=*/false, + prop.multiProcessorCount, + 0); + + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Compare + D.to_cpu(); + RefD.to_cpu(); + + const bf16 *d_ptr = D.rowwise_cpu_dptr(); + const bf16 *ref_ptr = RefD.rowwise_cpu_dptr(); + double max_atol = 0.0; + double max_rtol = 0.0; + const double log_threshold = 5e-2 + K * 2e-4; + int mismatch_count = 0; + for (size_t i = 0; i < M * N; i++) { + float actual = static_cast(d_ptr[i]); + float expected = static_cast(ref_ptr[i]); + double diff = std::abs(actual - expected); + double denom = std::max(std::abs((double)expected), 1e-6); + if (diff > log_threshold && mismatch_count < 10) { + size_t row = i / N; + size_t col = i % N; + std::cout << " MISMATCH [" << row << "," << col << "]: actual=" << actual + << " expected=" << expected << " diff=" << diff << std::endl; + mismatch_count++; + } + max_atol = std::max(max_atol, diff); + max_rtol = std::max(max_rtol, diff / denom); + } + + // MXFP8 GEMM tolerance: FP8 E4M3 accumulation errors grow with K + // because hardware and reference kernels use different reduction orders. + const double ATOL = 5e-2 + K * 2e-4; + constexpr double RTOL = 1.5e-2; + EXPECT_LE(max_atol, ATOL) << "Absolute error too large: " << max_atol; + EXPECT_LE(max_rtol, RTOL) << "Relative error too large: " << max_rtol; +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxGemmTestSuite, + ::testing::Values( + MxGemmParams{32, 128, 16}, + MxGemmParams{64, 128, 32}, + MxGemmParams{128, 128, 64}, + MxGemmParams{64, 256, 32}, + MxGemmParams{128, 384, 64}, + MxGemmParams{256, 512, 128}, + MxGemmParams{512, 1024, 256}, + MxGemmParams{1024, 2048, 128}, + MxGemmParams{4096, 8192, 64} + ), + [](const testing::TestParamInfo &info) { + return "M" + std::to_string(info.param.m) + + "_K" + std::to_string(info.param.k) + + "_N" + std::to_string(info.param.n); + }); + +// MX pre-swizzle test (gfx1250 Tensile 3D layout) +// +// Tensile 3D: {K_scale, M}.reshape({K_scale, padM/4, 4}).permute({1, 0, 2}) +// For source (m, k): dst = (m/4) * (K*4) + k*4 + (m%4) + +// CPU reference for Tensile 3D MX scale pre-swizzle. +// Row-major input [M, K], output is a flat permuted array. +void compute_ref_mx_swizzle_row(const uint8_t *h_input, uint8_t *h_output, + const int M, const int K, + const int orig_M, const int orig_K) { + constexpr int GROUP = 4; + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + uint8_t val = 127; // E8M0 identity: 2^0 = 1.0 + if (m < orig_M && k < orig_K) { + val = h_input[m * orig_K + k]; + } + int group = k / GROUP; + int within = k % GROUP; + int dst = group * (M * GROUP) + m * GROUP + within; + h_output[dst] = val; + } + } +} + +void compute_ref_mx_swizzle_col(const uint8_t *h_input, uint8_t *h_output, + const int M, const int K, + const int orig_M, const int orig_K) { + constexpr int GROUP = 4; + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + uint8_t val = 127; + if (m < orig_M && k < orig_K) { + val = h_input[k * orig_M + m]; + } + int group = k / GROUP; + int within = k % GROUP; + int dst = group * (M * GROUP) + m * GROUP + within; + h_output[dst] = val; + } + } +} + +static size_t roundup_sz(size_t val, size_t mult) { + return ((val + mult - 1) / mult) * mult; +} + +class MxSwizzleTestSuite + : public ::testing::TestWithParam< + std::tuple, bool>> {}; + +TEST_P(MxSwizzleTestSuite, TestMxSwizzle) { + using namespace transformer_engine; + using namespace test; + + const auto dims = std::get<0>(GetParam()); + const bool rowwise = std::get<1>(GetParam()); + + // Original (unpadded) scale dimensions + const size_t orig_M = dims.first; + const size_t orig_K = dims.second; + + // Padded dimensions: K-tiled layout requires K_scale padded to multiple of 4 + const size_t M = orig_M; + const size_t K = roundup_sz(orig_K, 4); + + // Allocate host input (unpadded) and fill with random data + const size_t input_size = orig_M * orig_K; + std::unique_ptr h_input(new uint8_t[input_size]); + std::mt19937 rng(42); + for (size_t i = 0; i < input_size; i++) { + h_input[i] = static_cast(rng() % 256); + } + + // Allocate device input + uint8_t *d_input = nullptr; + ASSERT_EQ(cudaMalloc(&d_input, input_size), cudaSuccess); + ASSERT_EQ(cudaMemcpy(d_input, h_input.get(), input_size, cudaMemcpyHostToDevice), cudaSuccess); + + // Allocate device output (padded size) + const size_t output_size = M * K; + uint8_t *d_output = nullptr; + ASSERT_EQ(cudaMalloc(&d_output, output_size), cudaSuccess); + ASSERT_EQ(cudaMemset(d_output, 0, output_size), cudaSuccess); + + // Build TensorWrapper for input and output + TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); + TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); + output_tw.set_with_gemm_swizzled_scales(true); + + // Data shape must be consistent with scale shape for validation. + // Scale shapes use padded K; data shapes use unpadded dims + // (kernel derives original_M/K from them). + if (rowwise) { + std::vector data_shape_in = {orig_M, orig_K * 32}; + std::vector data_shape_out = {M, K * 32}; + std::vector scale_shape_in = {M, K}; + std::vector scale_shape_out = {M, K}; + input_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); + input_tw.set_rowwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); + output_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); + output_tw.set_rowwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); + } else { + std::vector data_shape_in = {orig_K * 32, orig_M}; + std::vector data_shape_out = {K * 32, M}; + std::vector scale_shape_in = {K, M}; + std::vector scale_shape_out = {K, M}; + input_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); + input_tw.set_columnwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); + output_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); + output_tw.set_columnwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); + } + + nvte_swizzle_scaling_factors_mx(input_tw.data(), output_tw.data(), 0); + + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Copy output back to host + std::unique_ptr h_output(new uint8_t[output_size]); + ASSERT_EQ(cudaMemcpy(h_output.get(), d_output, output_size, cudaMemcpyDeviceToHost), + cudaSuccess); + + // Compute reference + std::unique_ptr h_ref(new uint8_t[output_size]); + memset(h_ref.get(), 0, output_size); + if (rowwise) { + compute_ref_mx_swizzle_row(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); + } else { + compute_ref_mx_swizzle_col(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); + } + + // Compare + compareResults("mx_swizzle", h_output.get(), h_ref.get(), output_size); + + cudaFree(d_input); + cudaFree(d_output); +} + +namespace { + +// Scale dimensions (M_scale, K_scale). +// K_scale will be padded to multiple of 4 by the test. +std::vector> mx_scale_dims = { + {4, 4}, // minimal + {8, 4}, // small + {32, 8}, // medium + {64, 16}, // larger + {96, 8}, // non-power-of-2 M + {128, 32}, // big + {256, 64}, // bigger + {512, 128}, // stress inter-tile + {1024, 256}, // large + {4096, 256}, // max stress +}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxSwizzleTestSuite, + ::testing::Combine( + ::testing::ValuesIn(mx_scale_dims), + ::testing::Values(true, false) + ), + [](const testing::TestParamInfo& info) { + std::string name = "M" + std::to_string(std::get<0>(info.param).first) + + "_K" + std::to_string(std::get<0>(info.param).second) + + (std::get<1>(info.param) ? "_row" : "_col"); + return name; + }); + +#endif // __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index c0e82b8ff..c37245727 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -26,6 +26,7 @@ #include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/vectorized_pointwise.h" #include "../util/logging.h" diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 5e420b2d4..aeed0e8fe 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -64,6 +64,22 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Swizzle MX (E8M0) scaling factors into gfx1250 Tensile 3D layout for GEMM + * + * Tensile 3D layout: groups M into blocks of 4, then permutes {1, 0, 2}. + * For source (m, k): dst = (m/4) * (K_scale * 4) + k * 4 + (m % 4) + * + * \param[in] input Input tensor with non-swizzled scale_inv (MXFP8). + * \param[in,out] output Output tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - Input scaling mode is NVTE_MXFP8_1D_SCALING. + * - scale_inv M dimension is padded to a multiple of 4. + */ +void nvte_swizzle_scaling_factors_mx(const NVTETensor input, NVTETensor output, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index c634c73fb..e7fa9a5fc 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -14,6 +14,7 @@ #include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" @@ -347,9 +348,178 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); } +// ============================================================================ +// MX scale pre-swizzle kernel for gfx1250 — K-tiled 3D layout +// +// hipBLASlt Tensile kernels expect scales in a permuted 3D layout that +// groups K_scale into tiles of 4 (= 128 / MXBlock32): +// Tensor({M, K_scale}).pad(K_scale to mult of 4).reshape({M, K_scale/4, 4}).permute({1, 0, 2}) +// +// For source position (m, k) in the [M, K_scale] scale matrix: +// group = k / 4 +// within = k % 4 +// dst = group * (M * 4) + m * 4 + within +// +// Padding: K_scale to multiple of 4. No M padding required. +// Identity padding value: E8M0 127 = 2^0 = 1.0 +// +// Reference: swizzle_mx_scale() in hipblaslt/clients/common/include/testing_matmul.hpp +// ============================================================================ + +constexpr int MX_PRESWIZZLE_GROUP_SIZE = 4; + +// Row-wise: input is [M, K_scale] row-major (K_scale contiguous) +__global__ void __launch_bounds__(256) + swizzle_row_scaling_mx_kernel(const uint8_t* __restrict__ input, + uint8_t* __restrict__ output, + const int M, const int K_scale, + const int original_M, const int original_K) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int total = M * K_scale; + if (idx >= total) return; + + const int m = idx / K_scale; + const int k = idx % K_scale; + + uint8_t val = 127; // E8M0 identity: 2^0 = 1.0 + if (m < original_M && k < original_K) { + val = input[m * original_K + k]; + } + + const int group = k / MX_PRESWIZZLE_GROUP_SIZE; + const int within = k % MX_PRESWIZZLE_GROUP_SIZE; + const int dst = group * (M * MX_PRESWIZZLE_GROUP_SIZE) + + m * MX_PRESWIZZLE_GROUP_SIZE + within; + + output[dst] = val; +} + +// Col-wise: input is [K_scale, M] row-major (M contiguous), representing +// the column-wise scale matrix logically shaped [M, K_scale]. +// Logical (m, k) maps to physical address k * original_M + m. +__global__ void __launch_bounds__(256) + swizzle_col_scaling_mx_kernel(const uint8_t* __restrict__ input, + uint8_t* __restrict__ output, + const int M, const int K_scale, + const int original_M, const int original_K) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int total = M * K_scale; + if (idx >= total) return; + + const int m = idx / K_scale; + const int k = idx % K_scale; + + uint8_t val = 127; + if (m < original_M && k < original_K) { + val = input[k * original_M + m]; // column-major read + } + + const int group = k / MX_PRESWIZZLE_GROUP_SIZE; + const int within = k % MX_PRESWIZZLE_GROUP_SIZE; + const int dst = group * (M * MX_PRESWIZZLE_GROUP_SIZE) + + m * MX_PRESWIZZLE_GROUP_SIZE + within; + + output[dst] = val; +} + } // namespace +void swizzle_scaling_factors_mx(const Tensor* input, Tensor* output, cudaStream_t stream) { + // Check scaling mode + const auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING, + "MX pre-swizzle only supports MXFP8 scaling mode (got ", + to_string(input->scaling_mode), ")."); + + // Check tensors + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + NVTE_CHECK(!input->with_gemm_swizzled_scales, + "Expected input tensor with scales in compact format."); + NVTE_CHECK(output->with_gemm_swizzled_scales, + "Expected output tensor with scales in GEMM swizzled format."); + NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ", + to_string(input->dtype()), ")."); + + // Check if scaling factors are non-trivial + const bool has_rowwise_scale_inv = input->scale_inv.has_data(); + const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); + NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv, + "Input tensor has both row-wise and column-wise scaling factors"); + if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) { + return; + } + + // Deduce tensor dims + int m{0}, k{0}; + if (has_rowwise_scale_inv) { + NVTE_CHECK(input->scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else if (has_columnwise_scale_inv) { + NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, "."); + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; + } + + // Check dims -- K-tiled layout requires K_scale padded to multiple of 4 + NVTE_CHECK(k % MX_PRESWIZZLE_GROUP_SIZE == 0, + "Scale K dimension must be padded to multiple of ", MX_PRESWIZZLE_GROUP_SIZE, + ", got ", k, "."); + + // Validate output dimensions match + if (has_rowwise_scale_inv) { + NVTE_CHECK(output->scale_inv.has_data(), + "Output tensor does not have row-wise scaling factors."); + NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + } + if (has_columnwise_scale_inv) { + NVTE_CHECK(output->columnwise_scale_inv.has_data(), + "Output tensor does not have column-wise scaling factors."); + NVTE_CHECK(m * k == output->columnwise_scale_inv.numel(), + "Expected output tensor to have ", m * k, + " column-wise scaling factors, but got shape=", + output->columnwise_scale_inv.shape, "."); + } + + const int total = m * k; + constexpr int block = 256; + const int grid = (total + block - 1) / block; + + // Row-wise swizzle + if (has_rowwise_scale_inv) { + const int original_M = input->flat_first_dim(); + const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; + swizzle_row_scaling_mx_kernel<<>>( + reinterpret_cast(input->scale_inv.dptr), + reinterpret_cast(output->scale_inv.dptr), + m, k, original_M, original_K); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + // Column-wise swizzle + if (has_columnwise_scale_inv) { + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; + swizzle_col_scaling_mx_kernel<<>>( + reinterpret_cast(input->columnwise_scale_inv.dptr), + reinterpret_cast(output->columnwise_scale_inv.dptr), + m, k, original_M, original_K); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} + void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { + // On gfx1250, MXFP8 uses the K-tiled pre-swizzle layout + // (K_scale grouped by 4, matching hipBLASlt BLK32_UE8M0_32_8_EXT). + if (input->scaling_mode == NVTE_MXFP8_1D_SCALING && cuda::sm_arch() == 125) { + swizzle_scaling_factors_mx(input, output, stream); + return; + } + // Check scaling mode const auto& scaling_mode = input->scaling_mode; NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, @@ -667,6 +837,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { + // On gfx1250, MXFP8 uses the MX pre-swizzle layout. + // Dispatch each tensor individually through the MX pre-swizzle path. + if (cuda::sm_arch() == 125) { + bool any_mxfp8 = false; + for (size_t i = 0; i < input.size(); i++) { + if (is_mxfp8_scaling(input[i]->scaling_mode)) { + any_mxfp8 = true; + } + } + if (any_mxfp8) { + for (size_t i = 0; i < input.size(); i++) { + swizzle_scaling_factors_mx(input[i], output[i], stream); + } + return; + } + } + auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; @@ -859,3 +1046,11 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen } multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); } + +void nvte_swizzle_scaling_factors_mx(const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_scaling_factors_mx); + using namespace transformer_engine; + swizzle_scaling_factors_mx(convertNVTETensorCheck(input), convertNVTETensorCheck(output), + stream); +} diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index bb960406d..f1c5a882c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -9,6 +9,9 @@ #include #include "common.h" +#ifdef USE_ROCM +#include "common/util/cuda_runtime.h" +#endif #include "pybind.h" #include "torch/torch.h" @@ -1104,6 +1107,19 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); #ifdef USE_ROCM + // gfx1250 MX pre-swizzle (Tensile 3D) layout requires M padded to multiple of 4. + // Other ROCm architectures use 128x4 tiles but currently skip padding + // (the swizzle kernel handles out-of-bounds reads). + if (transformer_engine::cuda::sm_arch() == 125) { + size_t m_dim = numel / last_dim; + size_t k_scale = last_dim / MXFP8_BLOCK_SIZE; + if (!columnwise) { + return {roundup(m_dim, 4), k_scale}; + } else { + return {k_scale, roundup(m_dim, 4)}; + } + } + return !columnwise ? std::vector{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE} : std::vector{numel / last_dim / MXFP8_BLOCK_SIZE, last_dim};