From 3e305f72bf855c9b3de0c68b38dfa9c535e4e78d Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 3 Apr 2025 22:01:52 -0700 Subject: [PATCH 01/29] [PyTorch] Debug weight matrix usages for dgrad GEMM (#1637) Make sure that weight matrix has required usages for dgrad GEMM Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/layernorm_linear.py | 5 ++--- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 928e6c4adb..5fb986bdc3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -327,9 +327,8 @@ def forward( ln_out.update_usage(rowwise_usage=False) # Weight with column-wise usage is needed for dgrad GEMM. - if inp.requires_grad: - if isinstance(weightmat, QuantizedTensor): - weightmat.update_usage(columnwise_usage=True) + if isinstance(weightmat, QuantizedTensor): + weightmat.update_usage(columnwise_usage=True) if cpu_offloading: if fp8 and weightmat is not None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 0f324f4489..7dae573688 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -415,7 +415,7 @@ def forward( ) # Weight with column-wise usage is needed for dgrad GEMM. - if is_grad_enabled and inp.requires_grad: + if is_grad_enabled: if isinstance(fc1_weight_final, QuantizedTensor): fc1_weight_final.update_usage(columnwise_usage=True) if isinstance(fc2_weight_final, QuantizedTensor): From 1bbeab1c563e7b8551804cb5af0847d277e22951 Mon Sep 17 00:00:00 2001 From: kwyss-nvidia Date: Thu, 3 Apr 2025 23:39:15 -0700 Subject: [PATCH 02/29] Blockwise float8 quantizer and quantized tensor class (#1513) * Blockwise float8 quantizer and quantized tensor class. The classes are configurable for 128x128 blocksize and 1x128 blocksize via setting block_scaling_dim == 2,1 respectively. Scale tensors are stored in a format emenable for matrix multiplication, however the integration of matmul is deferred as a separate story. Fusions of quantization and DBIAS or activation functions are not yet implemented, and the dequantization is currently implemented in torch. Tests for quantization are included in C++ and pytorch layers, with exact comparison to reference quantizer behavior as well as an attempt to hit interesting branches through the API such as tensor creation in pytorch and CPP and dequantization of row and columnwise usage. Two CUDA kernels for quantization are included, and are direct ports of equivalents in the kitchen repository, where a subchannel recipe has been used for end to end training. Signed-off-by: Keith Wyss * Apply linting changes. Signed-off-by: Keith Wyss * Alignment for 1D scaling for GEMM edge case. Signed-off-by: Keith Wyss * MR feedback. Signed-off-by: Keith Wyss * Change API name. Signed-off-by: Keith Wyss * Fix merge conflict with name change. Signed-off-by: Keith Wyss * Use common tensor map API. Signed-off-by: Keith Wyss * Change API to use two scaling mode enums. Signed-off-by: Keith Wyss * Fix typo. Signed-off-by: Keith Wyss * Update some call sites. Signed-off-by: Keith Wyss * Tests for torch tensor API surface. Since the quantized tensor is a tensor subclass, these tests exercise torch hooks. Signed-off-by: Keith Wyss * Reuse scale calculation between quantizer refs. Signed-off-by: Keith Wyss * Save memory by dropping reference to saved tensors. Issues previously observed are solved. Signed-off-by: Keith Wyss * Remove constexpr parameters from kernel. Code size is reduced with fewer constexpr params. Signed-off-by: Keith Wyss * Merge conflict from rebase. Signed-off-by: Keith Wyss * Add shape implementations for block scaling. nvte_shape was added upstream. Logic added for block scaled fp8. Signed-off-by: Keith Wyss * Move benchmark to te_playground Signed-off-by: Keith Wyss * Remove amax_epsilon and pow_2_scales from tensor. Hardcodes the default values. Signed-off-by: Keith Wyss * Lint changes. Signed-off-by: Keith Wyss * Fixup MR changes that broke. Signed-off-by: Keith Wyss * Safer ifdef in kernel. Signed-off-by: Keith Wyss * Documentation prose. Signed-off-by: Keith Wyss * Reuse compute_scale function from Current Scaling. Signed-off-by: Keith Wyss * Bugfix on inf_value scale refactor. Signed-off-by: Keith Wyss * Remove qopt calls from test. Signed-off-by: Keith Wyss * Update pytest list. Signed-off-by: Keith Wyss * Add copyright to reference scale calc. Signed-off-by: Keith Wyss * Use ptx.cuh functions instead of cde. Signed-off-by: Keith Wyss * Update shape logic with allocation and reuse shape. Signed-off-by: Keith Wyss * Usage defaults MR feedback. Signed-off-by: Keith Wyss * Copyright and header guard. Signed-off-by: Keith Wyss * Updating torch dispatch code. Signed-off-by: Keith Wyss * Fix exception type. Signed-off-by: Keith Wyss * Use TypeInfo Signed-off-by: Keith Wyss * MR feedback. Signed-off-by: Keith Wyss * Update CS scale update test to use updated ref impl Signed-off-by: Tim Moon * Update JAX scaling mode enum Signed-off-by: Tim Moon * Skip tests on Lovelace Signed-off-by: Tim Moon --------- Signed-off-by: Keith Wyss Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 2 + tests/cpp/operator/CMakeLists.txt | 1 + .../cpp/operator/test_cast_float8blockwise.cu | 641 ++++++++++++++++++ tests/cpp/test_common.cu | 148 ++-- tests/cpp/test_common.h | 35 +- .../blockwise_quantizer_reference.py | 302 +++++++++ .../pytorch/references/quantize_scale_calc.py | 60 ++ tests/pytorch/references/ref_per_tensor_cs.py | 59 +- .../test_float8_blockwise_scaling_exact.py | 294 ++++++++ tests/pytorch/test_float8blockwisetensor.py | 442 ++++++++++++ tests/pytorch/test_multi_tensor.py | 11 +- transformer_engine/common/CMakeLists.txt | 2 + transformer_engine/common/common.h | 64 ++ .../common/gemm/cublaslt_gemm.cu | 8 + .../common/include/transformer_engine/cast.h | 63 +- .../transformer_engine/transformer_engine.h | 10 +- .../common/recipe/current_scaling.cu | 3 +- .../common/recipe/recipe_common.cuh | 37 +- .../common/transformer_engine.cpp | 42 +- .../common/transpose/cast_transpose.h | 12 + .../quantize_transpose_square_blockwise.cu | 561 +++++++++++++++ .../quantize_transpose_vector_blockwise.cu | 479 +++++++++++++ .../common/util/cast_kernels.cuh | 24 + .../common/util/dequantize_kernels.cuh | 1 + transformer_engine/common/util/ptx.cuh | 55 +- .../jax/quantize/scaling_modes.py | 4 +- transformer_engine/pytorch/constants.py | 6 + transformer_engine/pytorch/csrc/common.h | 32 + .../multi_tensor_compute_scale.cu | 6 +- .../pytorch/csrc/extensions/pybind.cpp | 26 + .../pytorch/csrc/extensions/quantizer.cpp | 146 +++- .../csrc/extensions/type_converters.cpp | 32 + transformer_engine/pytorch/csrc/pybind.h | 19 +- .../_internal/float8_blockwise_tensor_base.py | 240 +++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 608 +++++++++++++++++ 35 files changed, 4254 insertions(+), 221 deletions(-) create mode 100644 tests/cpp/operator/test_cast_float8blockwise.cu create mode 100644 tests/pytorch/references/blockwise_quantizer_reference.py create mode 100644 tests/pytorch/references/quantize_scale_calc.py create mode 100644 tests/pytorch/test_float8_blockwise_scaling_exact.py create mode 100644 tests/pytorch/test_float8blockwisetensor.py create mode 100644 transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu create mode 100644 transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu create mode 100644 transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/float8_blockwise_tensor.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 8d38fa59df..21eaededc4 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -30,6 +30,8 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 6785dbf6f4..0b0e615495 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu new file mode 100644 index 0000000000..cc27f72769 --- /dev/null +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -0,0 +1,641 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +constexpr size_t kBlockLen = 128; + +enum ProcessingMethod { + CAST_ONLY, + // CAST_DBIAS, + // CAST_DBIAS_DACT, + // CAST_DACT, + // CAST_ACT +}; + +enum ActivationType { + Identity, + // GeLU, + // SiLU, + // ReLU, + // QGeLU, + // SReLU +}; + +template +void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale_out, + float* qscale_inv_out) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + float eps = opts.amax_epsilon; + amax = std::max(amax, eps); + float qscale = quant_type_max_val / amax; + if (std::isinf(qscale)) { + qscale = input_type_max_val; + } + if (std::isnan(qscale) || amax == 0) { + qscale = 1.0; + } + + if (opts.force_pow_2_scales && qscale != 0.0) { + uint32_t scale_bits = *reinterpret_cast(&qscale); + // Scale must be positive, shift it + uint8_t exp = scale_bits >> 23; + ASSERT_FALSE(exp == 0) << "Subnormals in this path is a logic error."; + qscale = ldexpf(1.0f, static_cast(exp) - 127); + } + + float qscale_inv = 1.0 / qscale; + *qscale_out = qscale; + *qscale_inv_out = qscale_inv; +} + +template +void ref_quantize(const ProcessingMethod processing_method, const InputType* input, + const std::pair& input_hw, OutputType* output, float* scale_inv, + OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) { + constexpr size_t kBlockLenX = kBlockLen; + constexpr size_t kBlockLenY = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_y = (height + kBlockLenY - 1) / kBlockLenY; + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t block_y = 0; block_y < blocks_y; ++block_y) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + float val = static_cast(input[y_pos * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + // NOTE: This reference function outputs contigous scale tensors. + // It calculates a naive scale data format. Strides are handled + // in comparison. + if (scale_inv != nullptr) { + scale_inv[block_y * blocks_x + block_x] = qscale_inv; + } + if (scale_inv_t != nullptr) { + scale_inv_t[block_x * blocks_y + block_y] = qscale_inv; + } + + for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + if (output != nullptr) { + output[y_pos * width + x_pos] = quantize_element(input[y_pos * width + x_pos], qscale); + } + if (output_t != nullptr) { + output_t[x_pos * height + y_pos] = + quantize_element(input[y_pos * width + x_pos], qscale); + } + } + } + } + } +} + +template +void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method, + const InputType* input, + const std::pair& input_hw, + OutputType* output, float* scale_inv, OutputType* output_t, + float* scale_inv_t, const QuantizationOptions& opts) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + + constexpr size_t kBlockLenX = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_x_t = (height + kBlockLenX - 1) / kBlockLenX; + if (output != nullptr && scale_inv != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t y = 0; y < height; ++y) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + float val = static_cast(input[y * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv[y + height * block_x] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + output[y * width + x_pos] = quantize_element(input[y * width + x_pos], qscale); + } + } + } + } + if (output_t != nullptr && scale_inv_t != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x_t = 0; block_x_t < blocks_x_t; ++block_x_t) { + for (size_t x = 0; x < width; ++x) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + float val = static_cast(input[x + y_pos * width]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv_t[x + width * block_x_t] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + output_t[x * height + y_pos] = quantize_element(input[y_pos * width + x], qscale); + } + } + } + } +} + +inline size_t scale_align_stride(size_t inner_elements) { + return ((inner_elements + 4u - 1u) / 4u) * 4u; +}; + +void compare_scaling_factors(const std::string& name, const float* test, const float* ref, + const size_t row_blocks, const size_t col_blocks, + const size_t test_stride, const size_t ref_stride) { + for (int i = 0; i < row_blocks; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i * test_stride + j; + const int ref_idx = i * ref_stride + j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, + const float* ref, const size_t rows, + const size_t col_blocks) { + const size_t test_stride = scale_align_stride(rows); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i + test_stride * j; + const int ref_idx = i + rows * j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +template +void runTestCase(const ProcessingMethod processing_method, const std::vector& shape, + const bool rowwise, const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_y = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, + opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(blocks_y * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(blocks_y * blocks_x); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize(processing_method, input.rowwise_cpu_dptr(), + {rows, cols}, ref_output.get(), ref_scale_inv.get(), + ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors("scale_inv", output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), blocks_y, blocks_x, scale_align_stride(blocks_x), + blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors("scale_inv_t", output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), blocks_x, blocks_y, scale_align_stride(blocks_y), + blocks_y); + } +} + +template +void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, + const std::vector& shape, const bool rowwise, + const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_x_t = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, + opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(rows * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(cols * blocks_x_t); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize_onedimensional_blocks( + processing_method, input.rowwise_cpu_dptr(), {rows, cols}, ref_output.get(), + ref_scale_inv.get(), ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv", + output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), rows, blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv_t", + output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), cols, blocks_x_t); + } +} + +std::vector> matrix_sizes = { + {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, + {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, + {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, +}; + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + // ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + + +std::vector amax_epsilons = { + 0.0f, +}; + +} // namespace + +class FusedCastFloat8BlockwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +class FusedCastFloat8VectorwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 2u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. + GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCase(processing_method, matrix_size, rowwise, colwise, + fill_case, q_opts);););); +} + +TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 1u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. + GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCaseOneDimensionalBlocks( + processing_method, matrix_size, rowwise, colwise, fill_case, q_opts);););); +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: + return "CAST_ONLY"; + // case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + // case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + // case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + // case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: + return ""; + } +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: + return "Identity"; + // case ActivationType::GeLU: return "GeLU"; + // case ActivationType::SiLU: return "SiLU"; + // case ActivationType::ReLU: return "ReLU"; + // case ActivationType::QGeLU: return "QGeLU"; + // case ActivationType::SReLU: return "SReLU"; + default: + return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8BlockwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8VectorwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 855d70856a..071c2186e0 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -134,27 +135,19 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul,4ul}; + auto block_alignment = std::vector{128ul, 4ul}; { auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat8E8M0; @@ -164,6 +157,58 @@ std::pair get_scales(const NVTEShape& shape, return {ret_rowwise, ret_colwise}; } + if (scaling_mode == NVTE_BLOCK_SCALING_2D) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(first_dim, 4) * 4; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(last_dim, 4) * 4; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + return {ret_rowwise, ret_colwise}; + } NVTE_ERROR("Invalid scaling mode!"); } @@ -171,7 +216,8 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode) { + const NVTEScalingMode &scaling_mode, + const QuantizationOptions* q_opts) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -198,7 +244,7 @@ Tensor::Tensor(const std::string& name, NVTEShape columnwise_shape{nullptr, 0}; std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { @@ -259,27 +305,33 @@ Tensor::Tensor(const std::string& name, std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, - tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(normalized_shape, tensor_.scaling_mode()); auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_shape = rowwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape; if (rowwise) { - cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape); + auto scale_dtype = rowwise_scale_meta.type; + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); } if (columnwise) { cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); - tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape); + auto scale_dtype = colwise_scale_meta.type; + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } + if (q_opts != nullptr) { + NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation."); + NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation."); + } } } @@ -311,7 +363,8 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -349,7 +402,8 @@ void Tensor::from_cpu() const { cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -368,7 +422,7 @@ void Tensor::from_cpu() const { void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; from_cpu(); } @@ -383,27 +437,29 @@ void Tensor::set_scale_inv(float scale_inv) { if (columnwise_) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(tensor_.shape(), tensor_.scaling_mode()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { rowwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } } if (columnwise_) { auto num_scales = product(colwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { columnwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } @@ -413,23 +469,20 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if(isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); - new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, - static_cast(my_rowwise_data.dtype), + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), my_rowwise_data.shape); auto my_columnwise_data = tensor_.get_columnwise_data(); new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, static_cast(my_columnwise_data.dtype), my_columnwise_data.shape); auto other_amax = other.tensor_.get_amax(); - new_tensor.set_amax(other_amax.data_ptr, - static_cast(other_amax.dtype), + new_tensor.set_amax(other_amax.data_ptr, static_cast(other_amax.dtype), other_amax.shape); auto other_scale = other.tensor_.get_scale(); - new_tensor.set_scale(other_scale.data_ptr, - static_cast(other_scale.dtype), + new_tensor.set_scale(other_scale.data_ptr, static_cast(other_scale.dtype), other_scale.shape); auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, @@ -460,9 +513,7 @@ std::string to_string(const std::vector &v) { std::vector unravel(const size_t i, const NVTEShape &shape) { std::vector ret; size_t current_i = i; - for (size_t current = shape.ndim - 1; - current > 0; - --current) { + for (size_t current = shape.ndim - 1; current > 0; --current) { ret.push_back(current_i % shape.data[current]); current_i /= shape.data[current]; } @@ -705,7 +756,7 @@ void fillCase_special(Tensor *t) { }); } else { double minAbs = -2.0; - double maxAbs = 1.0; + double maxAbs = 1.0; if constexpr (Case != InputsFillCase::uniform) { minAbs = Quantized_Limits::ranges[Case]; maxAbs = Quantized_Limits::ranges[Case + 1]; @@ -764,14 +815,13 @@ void setRandomScaleInv(Tensor *t) { } bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; +int32_t getDeviceComputeCapability() { + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; } size_t first_dimension(const std::vector &shape) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 4352056ddb..08df3cf7d1 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,21 +95,29 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} Tensor() {} @@ -136,25 +144,19 @@ class Tensor { if (scale_inv != nullptr) { cudaFree(scale_inv); } - if (columnwise_data_ptr != nullptr){ + if (columnwise_data_ptr != nullptr) { cudaFree(columnwise_data_ptr); } - if (columnwise_scale_inv != nullptr){ + if (columnwise_scale_inv != nullptr) { cudaFree(columnwise_scale_inv); } } - NVTETensor data() const noexcept { - return tensor_.data(); - } + NVTETensor data() const noexcept { return tensor_.data(); } - NVTEShape rowwise_shape() const noexcept { - return tensor_.get_rowwise_data().shape; - } + NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; } - NVTEShape columnwise_shape() const noexcept { - return tensor_.get_columnwise_data().shape; - } + NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; } NVTEShape rowwise_scale_inv_shape() const { NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); @@ -221,6 +223,8 @@ class Tensor { T *rowwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -232,6 +236,8 @@ class Tensor { T *columnwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -459,6 +465,7 @@ extern std::vector all_fp_types; bool isFp8Type(DType type); int32_t getDeviceComputeCapability(); +constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; } // namespace test diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py new file mode 100644 index 0000000000..b98966f514 --- /dev/null +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -0,0 +1,302 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import dataclasses +import math +import torch +from typing import Optional, Protocol, Tuple +from references.quantize_scale_calc import scale_from_amax_tensor + + +@dataclasses.dataclass() +class QuantizeResult: + data: torch.Tensor + scale: torch.Tensor + data_t: Optional[torch.Tensor] + scale_t: Optional[torch.Tensor] + + +@dataclasses.dataclass() +class CuBLASScaleMunger: + + def munge_scale_shapes_for_backend( + self, + unmunged: QuantizeResult, + tile_shape: Tuple[int, int], + ) -> QuantizeResult: + """ + cuBLAS GEMMs requires 1x128 quantized tensors to be have scales transposed + so that for an (M, N) tensor, the scales are (RoundUpDiv(N, 128), RoundUp(M, 4)) + + For 128x128 quantized tensors, the GEMM expects (M, PadToAlign(RoundUpDivide(N, 128), 4)) + format. If RoundUpDivide(N, 128) is not divisible by 4, a transformation is required + """ + + def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: + if transpose: + s = s.transpose(-1, -2).contiguous() + M, K = s.shape + if K % 4 == 0: + return s + k_pad = 4 - (K % 4) + return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous() + + s = _pad_inner_to_align(unmunged.scale, transpose=tile_shape[0] == 1) + if unmunged.scale_t is None: + s_t = None + else: + s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1) + return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + + def demunge_scale_shape_from_backend( + cls, + qtensor_shape: Tuple[int, int], + scales: torch.Tensor, + tile_shape: Tuple[int, int], + ) -> torch.Tensor: + """ + Inverse operation of munge_scale_shapes_for_backend + """ + if tile_shape[0] != 1: + # 2D block quantized tensor may need padding stripped off + derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1]) + else: + derived_scale_k_shape = qtensor_shape[0] + M, K = scales.shape + if derived_scale_k_shape != K: + scales = scales[:, :derived_scale_k_shape].contiguous() + if tile_shape[0] == 1: + return scales.transpose(-1, -2).contiguous() + else: + return scales + + +@dataclasses.dataclass() +class BlockwiseQuantizerReference: + """ + A reference QuantizeOp for subchannel/block hybrid quantization. + + Defers to ref GEMMs and quantizization formatting based on the backend. + """ + + def __init__(self) -> None: + self.scale_munger = CuBLASScaleMunger() + + @classmethod + def _quantize_square_block_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + pad_m_k = [0, 0] + if K % tile_len != 0: + pad_m_k[1] = tile_len - (K % tile_len) + if M % tile_len != 0: + pad_m_k[0] = tile_len - (M % tile_len) + + unpadded_m, unpadded_k = M, K + if pad_m_k[0] != 0 or pad_m_k[1] != 0: + x = torch.nn.functional.pad( + x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0 + ).contiguous() + M, K = x.shape + + x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len) + amax_grid = ( + torch.abs(x_tiled.transpose(-3, -2)) + .reshape(M // tile_len, K // tile_len, tile_len**2) + .amax(dim=-1) + ).float() + dtype_max = torch.finfo(quant_dtype).max + + scale, scale_inv, _ = scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + if unpadded_k != K or unpadded_m != M: + qx = qx[:unpadded_m, :unpadded_k].contiguous() + if return_transpose: + # Valid because of square block sizes + qx_t = qx.transpose(-1, -2).contiguous() + scale_inv_t = scale_inv.transpose(-1, -2).contiguous() + else: + qx_t = None + scale_inv_t = None + + return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t) + + @classmethod + def _quantize_vectorwise_reference( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + pow_2_scales: bool, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + M, K = x.shape + dtype_max = torch.finfo(quant_dtype).max + x_tiled = x.reshape(M, K // tile_len, tile_len) + amax_grid = torch.abs(x_tiled).amax(dim=-1).float() + scale, scale_inv, _ = scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + return qx, scale_inv + + @classmethod + def _quantize_vector_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + if K % tile_len == 0: + qref_input = x + else: + pad_amount = tile_len - (K % tile_len) + pad = (0, pad_amount) + qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0) + qout_padded, scale_inv = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if K % tile_len == 0: + qout = qout_padded + else: + qout = qout_padded[:, :K].contiguous() + + if return_transpose: + if M % tile_len == 0: + qref_input = x.transpose(-1, -2).contiguous() + else: + amount_to_pad = tile_len - (M % tile_len) + pad = (0, amount_to_pad) + qref_input = torch.nn.functional.pad( + x.transpose(-1, -2), pad, mode="constant", value=0 + ).contiguous() + qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if M % tile_len == 0: + qout_t = qout_t_padded + else: + qout_t = qout_t_padded[:, :M].contiguous() + else: + qout_t, scale_inv_t = None, None + + return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t) + + def ref_dequantize_rowwise( + self, + q: torch.Tensor, + quant_tile_shape: Tuple[int, int], + s: torch.Tensor, + dtype: torch.dtype, + ) -> torch.Tensor: + assert q.dim() == 2 + q_M, q_K = q.shape + s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape) + assert len(s.shape) == 2 + m_tiles, k_tiles = s.shape + M, K = q.shape + unpadded_m, unpadded_k = M, K + if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0: + m_pad_amount = (quant_tile_shape[0] - (M % quant_tile_shape[0])) % quant_tile_shape[0] + k_pad_amount = (quant_tile_shape[1] - (K % quant_tile_shape[1])) % quant_tile_shape[1] + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + M, K = q.shape + q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1]) + result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1) + result = result.view(M, K).to(dtype) + if M != unpadded_m or K != unpadded_k: + result = result[:unpadded_m, :unpadded_k].contiguous() + return result + + def quantize( + self, + x: torch.Tensor, + quant_dtype: torch.dtype, + return_transpose: bool = False, + eps: float = 0.0, + pow_2_scales: bool = False, + quant_tile_shape: Tuple[int, int] = (128, 128), + ) -> QuantizeResult: + # sanity checks + assert x.dim() == 2 + assert x.dtype in ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, + ), "Unsupported input dtype." + assert quant_dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ), "Unsupported quant dtype." + + assert quant_tile_shape in ((1, 128), (128, 128)) + if quant_tile_shape[0] == 1: + # Quantize row-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_vector_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[1], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) + else: + # Quantize block-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_square_block_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[0], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) diff --git a/tests/pytorch/references/quantize_scale_calc.py b/tests/pytorch/references/quantize_scale_calc.py new file mode 100644 index 0000000000..f36ddca3b2 --- /dev/null +++ b/tests/pytorch/references/quantize_scale_calc.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple +import torch + + +def scale_from_amax_tensor( + x_dtype: torch.dtype, + amax: torch.Tensor, + quant_dtype: torch.dtype, + *, + eps: float, + pow_2_scales: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Derives quantization and dequantization from amax and options. + + Reference implementation for scale calculation. + + Returns: + - scale: quantization scales + - scale_inv: dequantization scales + - amax: Amax tensor with updates made for extrema values. + """ + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + + # Compute scale factor + scale = torch.div(fp8_max, amax) + # Note frexp doesn't give back inf for exponent with an inf input + # We take care of inf before pow_2_scales + scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) + if pow_2_scales: + # Calculate rounded down exponent + _, exp = torch.frexp(scale) + # Positive numbers are always returned as mant, exp with + # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with + # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because + # of the shift. Subnormal and zero cases need not be considered because + # the smallest possible result of fp8_max / amax is still normal. + exp = exp - 1 + # No subnormals and zero. + assert (exp > -127).all() + unity = torch.tensor([1.0], device=exp.device) + torch.ldexp(unity, exp, out=scale) + # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales + # Return 0.0 for 0.0 scale for consistency with non-pow2 scale + # calculation. + scale = torch.where(amax == float("inf"), 0.0, scale) + + # Handle overflow cases for amax zero causing NaN + scale = torch.where(amax == 0, 1.0, scale) + + # Compute scale_inv + scale_inv = torch.reciprocal(scale) + + return scale, scale_inv, amax diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index dad0c42357..5e803f7ed5 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -6,63 +6,16 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType_To_Torch - - -# Compute scale and scale_inv from amax -def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): - # Clamping amax to avoid division by small numbers - amax = torch.max(amax, torch.tensor(eps)) - - # Compute scale factor - scale = torch.div(fp8_max, amax) - # Note frexp doesn't give back inf for exponent with an inf input - # We take care of inf before pow_2_scales - # option1: set scale to fp32 max when scale is inf - scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale) - # option2: when scale is inf, set scale to 1 - scale = torch.where(scale == torch.inf, 1.0, scale) - if pow_2_scales: - # Calculate rounded down exponent - _, exp = torch.frexp(scale) - # Positive numbers are always returned as mant, exp with - # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with - # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because - # of the shift. Subnormal and zero cases need not be considered because - # the smallest possible result of fp8_max / amax is still normal. - exp = exp - 1 - # No subnormals and zero. - assert (exp > -127).all() - # TODO: If/when adding a URM option an option is to cap to 126 - # rather than allowing the full range of FP32 (2 - 2^23) x 2^127 - # addresses cases where adding a mantissa overflows into inf scales. - # Not necessary currently without additional scale smudging options. - unity = torch.tensor([1.0], device=exp.device) - torch.ldexp(unity, exp, out=scale) - # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales - # Return 0.0 for 0.0 scale for consistency with non-pow2 scale - # calculation. - scale = torch.where(amax == float("inf"), 0.0, scale) - - # Handle overflow cases for amax zero causing NaN - scale = torch.where(amax == 0, 1.0, scale) - # Compute scale_inv - scale_inv = torch.reciprocal(scale) - - return scale, scale_inv +from references.quantize_scale_calc import scale_from_amax_tensor # compute amax and scale def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): x_fp32 = x.to(torch.float32) amax = torch.amax(torch.abs(x_fp32)).view(1) - assert amax.dtype == torch.float, "amax must be a float tensor." - fp8_max = torch.finfo(quant_dtype).max - - scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) - # Clamping amax to avoid division by small numbers - amax = torch.max(amax, torch.tensor(eps)) - - return scale, scale_inv, amax + return scale_from_amax_tensor( + torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales + ) def _multi_dim_transpose(tensor): @@ -113,7 +66,3 @@ def ref_per_tensor_cs_cast( qx_t = _multi_dim_transpose(qx) sx_t = sx return qx, sx, qx_t, sx_t - - -def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales): - return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py new file mode 100644 index 0000000000..e638fe8c5b --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -0,0 +1,294 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple +import math +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from references.blockwise_quantizer_reference import ( + BlockwiseQuantizerReference, + QuantizeResult, +) + +# TODO replace with call to fp8.py when recipe added. +recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 +reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." + + +def initialize_for_many_scales( + x_shape_2d: Tuple[int, int], tile_shape: Tuple[int, int], *, dtype: torch.dtype, device: str +) -> torch.Tensor: + """ + Put separate distributions into each quantization tile + to avoid many tiles having similar scale values and + causing false passes. + """ + tile_grid_shape = ( + math.ceil(x_shape_2d[0] / tile_shape[0]), + math.ceil(x_shape_2d[1] / tile_shape[1]), + ) + # Arbitrary size + max_val = 8192.0 + # Make a uniform distribution of [-max_val, max_val] + tile_extrema = torch.rand(*tile_grid_shape, dtype=dtype) * max_val * 2 - max_val + result = torch.empty(x_shape_2d, dtype=dtype, device=device) + tile_elements = tile_shape[0] * tile_shape[1] + for i in range(tile_grid_shape[0]): + for j in range(tile_grid_shape[1]): + target = tile_extrema[i, j].item() + step = target / (tile_elements) + if target == 0: + tile = torch.zeros(tile_shape, dtype=dtype, device=device) + else: + tile = torch.arange(0.0, target, step=step, dtype=dtype, device=device) + tile = tile.reshape(*tile_shape) + min_dst_vals = (i * tile_shape[0], j * tile_shape[1]) + max_dst_vals = ( + min((i + 1) * tile_shape[0], x_shape_2d[0]), + min((j + 1) * tile_shape[1], x_shape_2d[1]), + ) + max_src_vals = ( + max_dst_vals[0] - min_dst_vals[0], + max_dst_vals[1] - min_dst_vals[1], + ) + result[min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1]] = tile[ + : max_src_vals[0], : max_src_vals[1] + ] + return result + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (303, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim = 1 + elif tile_size == (128, 128): + block_scaling_dim = 2 + else: + raise ValueError("Non support tile size") + # This test runs a comparison of the ref class versus the class using + # CUDA kernels to quantize. They should quantize identically for pixels + # that are not DC values in the scale factor shape. + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device) + + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) + + assert x_fp8_sut._rowwise_data is not None + qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + assert x_fp8_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv + qx_t = x_fp8_sut._columnwise_data + sx_t = x_fp8_sut._columnwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, + ) + qx_ref, sx_ref, qx_t_ref, sx_t_ref = ( + qresult_ref.data, + qresult_ref.scale, + qresult_ref.data_t, + qresult_ref.scale_t, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + # Zero out values that are don't care values + # Scale format has padding. + scale_mask = torch.ones( + (math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx, scale_mask, None, None), tile_size + ).scale + sx = sx * scale_mask + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + assert qx_t is not None + qx_t = qx_t.view(dtype=quant_dtype) + assert qx_t_ref is not None + assert sx_t is not None + assert sx_t_ref is not None + scale_mask = torch.ones( + (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])), + device=sx_t.device, + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx_t, scale_mask, None, None), tile_size + ).scale + sx_t = sx_t * scale_mask + torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0) + else: + # should be None + assert qx_t is None and qx_t_ref is None + assert sx_t is None and sx_t_ref is None + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0], ids=["eps_0"]) +@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"]) +@pytest.mark.parametrize("tile_size", [(128, 128)]) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +def test_quantization_block_tiling_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + pow_2_scales: bool, + tile_size: Tuple[int, int], + extrema_high: bool, +) -> None: + # This test runs a single tile through a quantizer as a way to test + # branch coverage of scale computation. + te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim = 1 + elif tile_size == (128, 128): + block_scaling_dim = 2 + else: + raise ValueError("Non support tile size") + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + return_transpose = False + # Input + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + # Run cast and transpose kernel + # Internal call ops.quantize_tensorwise + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) + qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + sx = x_fp8_sut._rowwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, + ) + qx_ref, sx_ref = ( + qresult_ref.data, + qresult_ref.scale, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx.flatten()[0], sx_ref.flatten()[0], atol=0.0, rtol=0.0) + + if extrema_high: + expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max + if pow_2_scales: + expected_value = math.floor(math.log2(expected_value)) + expected_value = math.pow(2.0, expected_value) + expected_value = 1 / expected_value + elif not extrema_high and eps == 0: + expected_value = 1.0 + else: + assert not extrema_high + # eps is small enough to trigger inf in quant_dtype_max / eps + if pow_2_scales: + expected_value = math.pow(2.0, -127) + else: + expected_value = 1 / torch.finfo(x_dtype).max + torch.testing.assert_close( + sx.flatten()[0], + torch.tensor(expected_value, device=sx.device), + atol=0.0, + rtol=0.0, + ) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py new file mode 100644 index 0000000000..d030426b74 --- /dev/null +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -0,0 +1,442 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from collections.abc import Iterable +import io +import math +from typing import Any, Dict, List, Tuple, Union + +import pytest +import torch + +import transformer_engine.common.recipe +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from transformer_engine.pytorch.utils import get_device_compute_capability +import transformer_engine_torch as tex + +# PyTorch tensor dtypes +_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] +# TE FP8 dtypes +_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + +# Numerical tolerances with FP8 types +_tols: Dict[tex.DType, Dict[str, float]] = { + tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.08), + tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), +} + + +def _to_list(x: Union[Iterable, Any]) -> List: + """Convert to list if iterable, otherwise put in singleton list""" + if isinstance(x, Iterable): + return list(x) + else: + return [x] + + +# Types that can be interpreted as tensor dims +DimsType = Union[Iterable[int], int] + +# TODO replace with call to fp8.py when recipe added. +recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 +reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +class TestFloat8BlockwiseTensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_constructor( + self, + dims: DimsType = 1, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + dtype: torch.dtype = torch.float32, + is_2D_scaled: bool = True, + ) -> None: + """Call constructor and perform sanity checks""" + dims = _to_list(dims) + + rowwise = True + columnwise = True + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=rowwise, + columnwise=columnwise, + block_scaling_dim=2 if is_2D_scaled else 1, + ) + + scale_dims = quantizer.get_scale_shape(dims, columnwise=False) + columnwise_scale_dims = quantizer.get_scale_shape(dims, columnwise=True) + columnwise_dims = quantizer.get_columnwise_shape(dims) + tensor = Float8BlockwiseQTensor( + shape=dims, + dtype=dtype, + rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8), + rowwise_scale_inv=torch.zeros(scale_dims, device="cuda", dtype=torch.float32), + columnwise_data=torch.zeros(columnwise_dims, device="cuda", dtype=torch.uint8), + columnwise_scale_inv=torch.zeros( + columnwise_scale_dims, device="cuda", dtype=torch.float32 + ), + fp8_dtype=fp8_dtype, + is_2D_scaled=is_2D_scaled, + quantizer=quantizer, + ) + assert list(tensor.size()) == dims, "Incorrect dims" + assert tensor.dtype == dtype, "Incorrect nominal dtype" + assert tensor.is_cuda, "Incorrect device" + + def _test_quantize_dequantize( + self, + quantizer: Float8BlockQuantizer, + dtype: torch.dtype = torch.float32, + dims: DimsType = (23, 128), + rtol: float = 0.0, + atol: float = 0.0, + dequant_columnwise: bool = False, + use_cpp_allocation: bool = False, + ) -> None: + """Check numerical error when casting to FP8 and back""" + dims = _to_list(dims) + + # Initialize random data + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref_cuda = x_ref.to("cuda") + + # Cast to FP8 and back + if not use_cpp_allocation: + x_fp8 = quantizer.make_empty(shape=dims, device="cuda") + quantizer.update_quantized(x_ref_cuda, x_fp8) + else: + # This codepath allows the CPP binding to allocate the output + # tensor + x_fp8 = tex.quantize(x_ref_cuda, quantizer, None, None) + if dequant_columnwise: + # Strip out rowwise data to verify dequantization of + # columnwise data. + x_fp8.update_usage(rowwise_usage=False, columnwise_usage=True) + x_fp8 = x_fp8.dequantize(dtype=dtype).cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, rtol=rtol, atol=atol) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_quantize_dequantize_dtypes( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=False, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims( + self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool + ) -> None: + atol = _tols[tex.DType.kFloat8E4M3]["atol"] + rtol = _tols[tex.DType.kFloat8E4M3]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + ) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims_cpp_allocate_output( + self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + use_cpp_allocation=True, + ) + + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None: + """Test data accessors of Float8BlockwiseQTensor""" + device = "cuda" + dtype = torch.bfloat16 + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + fp8_dtype = tex.DType.kFloat8E4M3 + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + # Create FP8 tensor + x_fp8 = quantizer.quantize(x_hp) + + x_recovered = x_fp8.data + torch.testing.assert_close(x_recovered, x_hp, **_tols[fp8_dtype]) + + x_fp8.data = y_hp + y_recovered = x_fp8.data + torch.testing.assert_close(y_recovered, y_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: + """Test serialization of Float8BlockwiseQTensor""" + device = "cuda" + dtype = torch.bfloat16 + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + # Create FP8 tensor + x_fp8 = quantizer.quantize(x_hp) + + # Save tensor + buffer = io.BytesIO() + torch.save(x_fp8, buffer) + + # Load tensor + buffer.seek(0) + x_fp8_loaded = torch.load(buffer, weights_only=False) + + # Test that loaded tensor matches original + assert isinstance(x_fp8_loaded, Float8BlockwiseQTensor) + torch.testing.assert_close(x_fp8_loaded._rowwise_data, x_fp8._rowwise_data) + torch.testing.assert_close(x_fp8_loaded._columnwise_data, x_fp8._columnwise_data) + torch.testing.assert_close(x_fp8_loaded._rowwise_scale_inv, x_fp8._rowwise_scale_inv) + torch.testing.assert_close(x_fp8_loaded._columnwise_scale_inv, x_fp8._columnwise_scale_inv) + torch.testing.assert_close(x_fp8_loaded.data, x_fp8.data) + assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled + assert x_fp8_loaded.dtype == x_fp8.dtype + assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype + + # Test that dequantized values match + x_fp8_dequant = x_fp8.dequantize() + x_fp8_loaded_dequant = x_fp8_loaded.dequantize() + torch.testing.assert_close(x_fp8_loaded_dequant, x_fp8_dequant) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_inplace_ops( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test in-place operations""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + # Test in-place add + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + x_fp8.add_(y_fp8) + torch.testing.assert_close(x_fp8.dequantize(), x_hp + y_hp, **_tols[fp8_dtype]) + + # Test in-place subtract + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + x_fp8.sub_(y_fp8) + torch.testing.assert_close(x_fp8.dequantize(), x_hp - y_hp, **_tols[fp8_dtype]) + + # Test in-place multiply + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + x_fp8.mul_(y_fp8) + torch.testing.assert_close(x_fp8.dequantize(), x_hp * y_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_out_of_place_ops( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test out-of-place operations""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.quantize(x_hp.clone()) + y_fp8 = quantizer.quantize(y_hp.clone()) + + # Test exact operations + torch.testing.assert_close(-x_fp8, -x_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(x_fp8.abs(), x_hp.abs(), **_tols[fp8_dtype]) + + # Test elementwise operations + torch.testing.assert_close(x_fp8 + y_fp8, x_hp + y_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(x_fp8 - y_fp8, x_hp - y_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(x_fp8 * y_fp8, x_hp * y_hp, **_tols[fp8_dtype]) + torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_hp), **_tols[fp8_dtype]) + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8 + y_fp8, x_hp - y_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_view_same_shape( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test view operations that preserve tensor shape""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device) + quantizer.update_quantized(x_hp.clone(), x_fp8) + + # Test view with same shape + x_view = x_fp8.view(*dims) + torch.testing.assert_close(x_view.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_view.shape == x_fp8.shape, "Shape changed after view with same dims" + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_reshape_same_shape( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test reshape operations that preserve tensor shape""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device) + quantizer.update_quantized(x_hp.clone(), x_fp8) + + # Test reshape with same shape + x_reshape = x_fp8.reshape(*dims) + torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with same dims" + + # Test reshape with -1 canonicalization + new_dims = [-1, dims[1]] + x_reshape = x_fp8.reshape(*new_dims) + torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with -1" + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_reshape.dequantize(), -x_hp, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_clone_detach( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + ) -> None: + """Test clone and detach operations""" + device = "cuda" + x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) + + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + block_scaling_dim=block_scaling_dim, + ) + + x_fp8 = quantizer.quantize(x_hp.clone()) + + # Test clone + x_clone = x_fp8.clone() + torch.testing.assert_close(x_clone.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_clone.shape == x_fp8.shape, "Shape changed after clone" + + # Test detach + x_detach = x_fp8.detach() + torch.testing.assert_close(x_detach.dequantize(), x_hp, **_tols[fp8_dtype]) + assert x_detach.shape == x_fp8.shape, "Shape changed after detach" + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_clone.dequantize(), -x_hp, **_tols[fp8_dtype]) diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 4dc1ec087f..737b5ff2b0 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -9,7 +9,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.optimizers import MultiTensorApply -from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax +from references.quantize_scale_calc import scale_from_amax_tensor input_size_pairs = [ @@ -224,17 +224,18 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)]) @pytest.mark.parametrize("applier", appliers) @pytest.mark.parametrize("repeat", [1, 55]) -@pytest.mark.parametrize("max_fp8", [448.0, 57344.0]) +@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("pow_2_scales", [False, True]) @pytest.mark.parametrize("epsilon", [0.0, 100.0]) def test_multi_tensor_compute_scale_and_scale_inv( - input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon + input_size_pair, applier, repeat, fp8_dtype, pow_2_scales, epsilon ): sizea, sizeb = input_size_pair device = torch.device("cuda") overflow_buf = torch.zeros(1, dtype=torch.int32, device=device) a = torch.randn([sizea], dtype=torch.float32, device=device).abs() b = torch.randn([sizeb], dtype=torch.float32, device=device).abs() + max_fp8 = torch.finfo(fp8_dtype).max amax_list = [] for i in range(repeat): @@ -253,8 +254,8 @@ def test_multi_tensor_compute_scale_and_scale_inv( ) for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list): - scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax( - amax, max_fp8, epsilon, pow_2_scales + scale_ref, scale_inv_ref, _ = scale_from_amax_tensor( + torch.float32, amax, fp8_dtype, eps=epsilon, pow_2_scales=pow_2_scales ) torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0) torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3abb61df02..18a0124e5a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -58,6 +58,8 @@ list(APPEND transformer_engine_SOURCES transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise.cu activation/gelu.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ac58398551..b1fe436379 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -99,6 +99,12 @@ struct Tensor { SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; + private: + // Used as an allocation for nvte_tensor_shape + // if the shape has to be inferred from columnwise data. + mutable std::vector rowwise_shape_cache; + + public: NVTEScalingMode scaling_mode; Tensor() @@ -160,12 +166,39 @@ struct Tensor { return data.shape; } break; + case NVTE_BLOCK_SCALING_1D: + case NVTE_BLOCK_SCALING_2D: { + if (!has_data() && has_columnwise_data()) { + std::vector shape; + size_t ndim = columnwise_data.shape.size(); + shape.reserve(ndim); + for (size_t i = 0; i + 1 < ndim; ++i) { + shape.push_back(columnwise_data.shape[i + 1]); + } + if (ndim > 0) { + shape.push_back(columnwise_data.shape[0]); + } + return shape; + } else { + // NOTE: We may have removed the data pointer from + // data by setting usage. In that case, we return + // the non-null shape. It is our best guess at the most + // recent shape. + return data.shape; + } + break; + } default: NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); return {}; } } + const std::vector &rowwise_shape_ref() const { + rowwise_shape_cache = shape(); + return rowwise_shape_cache; + } + /*! Matrix height after tensor is flattened to 2D * * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted @@ -247,6 +280,36 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) #endif #undef TRANSFORMER_ENGINE_TYPE_NAME +template +struct TypeExtrema; + +template <> +struct TypeExtrema { + static constexpr float max = 448.0f; +}; + +template <> +struct TypeExtrema { + static constexpr float max = 57344.0f; +}; + +template <> +struct TypeExtrema { + // Hex float format of 1.(7 bits of 1) * 2 ^ 127 + static constexpr float max = 0x1.FEp127; +}; + +template <> +struct TypeExtrema { + // Hex float format of 1.(10 bits of 1) * 2 ^ 15 + static constexpr float max = 0x1.FFCp15; +}; + +template +struct TypeExtrema { + static constexpr float max = std::numeric_limits::max(); +}; + } // namespace detail template @@ -277,6 +340,7 @@ struct TypeInfo { constexpr static DType dtype = getType(); constexpr static size_t size = sizeof(T); + constexpr static float max_finite_value = detail::TypeExtrema::max; constexpr static const char *name = detail::type_name(); }; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 3234e087c3..f19465c44b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -81,6 +81,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, const int k, const int lda, const int ldb) { using namespace transformer_engine; + // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. + // Must either force them both into a common block scaling mode or loosen this + // restriction. NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); @@ -90,6 +93,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.lda = lda; ret.ldb = ldb; + // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases + // or need to be treated as `is_tensor_scaling`. if (is_tensor_scaling(A.scaling_mode)) { ret.A = A.data.dptr; ret.A_scale_inv = A.scale_inv.dptr; @@ -244,6 +249,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); + // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized + // GEMM types. + // Scaling factors. #if CUDA_VERSION >= 12080 cublasLtMatmulMatrixScale_t scaling_mode; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index d57975b2f4..7fa7957fa4 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -17,22 +17,31 @@ extern "C" { #endif -/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer) - * The implementation is per the microscaling format MXFP8 defined by the OCP specification: +/* Quantize the tensor + * + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. + * + * Supported formats are: + * + * 1) MXFP8 scaling (for compute capability 10.0 or newer) + * + * The MXFP8 implementation is per the microscaling format MXFP8 defined by the OCP specification: * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf * - * Supported modes of scaling (live scaling): - * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: + * + * Supported modes of MXFP8 scaling (live scaling) for scaling mode NVTE_MXFP8_1D_SCALING + * a) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: * - the scaled output tensor * - the corresponding scaling factors * The scaling factors are computed for blocks of the shape [1,32] * (i.e., each scaling factor spans 32 contiguous elements along rows). * - * 2) Columwise scaling (along the dim=1) computes one set of the output data. + * b) Columwise scaling (along the dim=1) computes one set of the output data. * The scaling factors are computed for blocks of the shape [32,1] * (i.e., each scaling factor spans 32 contiguous elements along columns). * - * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) + * c) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) * computes two sets of the output data: both 1) and 2). * * The shape of the MX block must be specified in the 'output' argument, @@ -40,25 +49,53 @@ extern "C" { * * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter * of the output tensor should be set to 0. + * + * 2) NVTE_DELAYED_TENSOR_SCALING that quantize the entire tensor + * using a single scaling factor. The absolute maximum value of the tensor should + * be precalculated either online (current scaling) or based on a tensor history + * (delayed scaling). The calls to nvte_quantize scale based on that data value. + * Note the NVTE_DELAYED_TENSOR_SCALING NVTEScalingMode is reused for online + * per tensor scaling. + * + * + * 3) FP8 block scaling formats NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D + * for compute capability of at least 9.0. These modes quantize the tensor by blocks + * of size 1x128 (with columnwise mode of 128x1) and 128x128 respectively. + * + * The supported modes are: + * a) Rowwise scaling yields output data: + * - the scaled output tensor in fp8 coefficients with identical shape to the + * input tensor. + * - Scale factors which are computed for either 1D 1x128 or 2D 128x128 blocks. + * b) Columnwise scaling yields output data: + * - the scaled output tensor in fp8 coefficients with a shape equivalent to + * the transpose of the input tensor. + * - Scale factors which are calculated for either 1D 128x1 or 2D 128x128 blocks + * of the input tensor. + * c) Both: In which both tensors and both scales are calculated. + * + * This quantization mode includes both the calculation of the scaling factors + * per-tile and quantization of the row and/or columnwise tiles. No precalculated + * absolute max is required. The scaling factors are also rounded to powers of 2. */ -/*! \brief Casts input tensor to FP8/MXFP8. - * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output FP8/MXFP8/BlockwiseFP8 tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. - * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output quantized tensor. * \param[out] noop Noop tensor. * \param[in] stream CUDA stream used for the operation. */ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 70086a1811..c539265e62 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -80,8 +80,14 @@ enum NVTEScalingMode { /*! Single scale per block of 32 elements consecutive in either rowwise or columnwise direction */ NVTE_MXFP8_1D_SCALING = 1, - NVTE_INVALID_SCALING = 2, - NVTE_NO_SCALING = 3 + /*! Tensor is split into NxN quantization tiles or 1xN quantization tiles, + which each yield a scale. The block_scaling_dim property of the quantizer + selects the granularity. + */ + NVTE_BLOCK_SCALING_1D = 2, + NVTE_BLOCK_SCALING_2D = 3, + NVTE_INVALID_SCALING = 4, + NVTE_NO_SCALING = 5 }; /*! \brief TE Tensor type diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index e53ab18360..197863569e 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -152,7 +152,8 @@ namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, const float epsilon) { - *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon); + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon, + std::numeric_limits::max()); } } // namespace diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh index c789a9b497..11f9bc1299 100644 --- a/transformer_engine/common/recipe/recipe_common.cuh +++ b/transformer_engine/common/recipe/recipe_common.cuh @@ -7,19 +7,21 @@ #ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ #define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ -#include +#include "common/common.h" namespace transformer_engine { __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8, - bool force_pow_2_scales, float epsilon) { + bool force_pow_2_scales, float epsilon, + float value_for_inf) { + // NOTE: NAN amax evaluates false for <, handled further down. if (amax < epsilon) { amax = epsilon; } float scale = 1.f; - if (isinf(amax) || amax == 0.f) { + if (isinf(amax) || amax == 0.f || isnan(amax)) { return scale; } @@ -32,18 +34,13 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f // the scale is not representable in FP32. if (isinf(scale)) { // use fp32 max to represent the scale - scale = std::numeric_limits::max(); + scale = value_for_inf; } - - if (isnan(scale)) { - scale = 1.f; - } - if (force_pow_2_scales) { uint32_t scale_bits = *reinterpret_cast(&scale); scale_bits &= 0xFF800000; // If the exponent was zero, we have a logic error. - __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0 || scale == 0.0); __builtin_assume(scale_bits != 0x80000000); scale = *reinterpret_cast(&scale_bits); } @@ -51,6 +48,26 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f return scale; } +// Calculate the quantization scale for an individual data element +// given the amax(abs(tile)) value for a given quantization tile. +// +// +// Arguments: +// IType: data type of the tensor being quantized (float or bf16) +// OType: quantized data type (e4m3 or e5m2) +// amax: The evaluation of amax(abs(tile)) for the quantization tile. +// eps: An epsilon used as a floor for amax. +// pow_2_scaling: Whether to force the scale to be a power of 2. +template +__device__ __forceinline__ float compute_scale_from_types(const float amax, const float eps, + const float pow_2_scaling) { + constexpr float fp8_max = TypeInfo::max_finite_value; + // NOTE: We're relying on compute_scale_from_amax to have behavior where it + // clips the mantissa of the max_finite_value if power of 2 scaling applies. + constexpr float value_for_inf = TypeInfo::max_finite_value; + return compute_scale_from_amax(amax, fp8_max, pow_2_scaling, eps, value_for_inf); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_ diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1f8bfca2c9..97df5892b6 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -215,48 +215,14 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { if (tensor == nullptr) { NVTE_ERROR("Invalid tensor"); } - NVTEShape ret; // Determine tensor shape depending on tensor format const auto &t = *reinterpret_cast(tensor); - switch (t.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!t.has_data() && t.has_columnwise_data()) { - // We can infer tensor shape if FP8 tensor only has FP8 data - // transpose. However, NVTEShape only contains a pointer and - // cannot store temporary data. We hack around this by caching - // the tensor shape within the empty FP8 data. - auto &shape_cache = const_cast &>(t.data.shape); - shape_cache.clear(); - if (!t.columnwise_data.shape.empty()) { - for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) { - shape_cache.push_back(t.columnwise_data.shape[i]); - } - shape_cache.push_back(t.columnwise_data.shape.front()); - } - ret.data = shape_cache.data(); - ret.ndim = shape_cache.size(); - } else { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (!t.has_data() && t.has_columnwise_data()) { - ret.data = t.columnwise_data.shape.data(); - ret.ndim = t.columnwise_data.shape.size(); - } else { - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); - } - break; - } - default: - NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", - transformer_engine::to_string(t.scaling_mode), "\""); - } + const std::vector &rowwise_shape = t.rowwise_shape_ref(); + NVTEShape ret; + ret.data = rowwise_shape.data(); + ret.ndim = rowwise_shape.size(); return ret; } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index ed9bd5f5f7..298d087337 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -23,6 +23,18 @@ template +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/recipe/recipe_common.cuh" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +#if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \ + (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) +#define TMA_HW_SUPPORTED +#endif + +namespace transformer_engine { +namespace { + +// const values configuration + +constexpr size_t kThreadsPerWarp = 32; +#ifdef TMA_HW_SUPPORTED +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 32; +constexpr size_t WARP_TILE_DIM_Y = 64; +constexpr size_t THREAD_TILE_DIM_X = 16; +constexpr size_t THREAD_TILE_DIM_Y = 4; +#else +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 64; +constexpr size_t WARP_TILE_DIM_Y = 32; +constexpr size_t THREAD_TILE_DIM_X = 8; +constexpr size_t THREAD_TILE_DIM_Y = 8; +#endif + +#ifdef TMA_HW_SUPPORTED +constexpr size_t NUM_BYTES_PER_BANK = 4; +constexpr size_t NUM_BANKS_PER_SHARED_ELEM = THREAD_TILE_DIM_Y / NUM_BYTES_PER_BANK; +constexpr size_t SHARED_BLOCK_TILE_DIM_Y = BLOCK_TILE_DIM; +constexpr size_t SHARED_BLOCK_TILE_DIM_X_BANKS = + BLOCK_TILE_DIM / (NUM_BYTES_PER_BANK * NUM_BANKS_PER_SHARED_ELEM); +constexpr size_t NUM_BANKS_Y_IN_WARP = WARP_TILE_DIM_Y / NUM_BYTES_PER_BANK; +#endif +constexpr size_t ELE_PER_THREAD = THREAD_TILE_DIM_X * THREAD_TILE_DIM_Y; +constexpr size_t THREADS_PER_BLOCK = BLOCK_TILE_DIM * BLOCK_TILE_DIM / ELE_PER_THREAD; +constexpr size_t NUM_WARPS_X_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_X; +constexpr size_t NUM_WARPS_Y_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_Y; +constexpr size_t NUM_WARPS_IN_BLOCK = NUM_WARPS_X_IN_BLOCK * NUM_WARPS_Y_IN_BLOCK; + +constexpr size_t NUM_THREADS_X_IN_WARP = WARP_TILE_DIM_X / THREAD_TILE_DIM_X; +constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP; + +#define MIN(a, b) (a < b ? a : b) + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon, + const __grid_constant__ CUtensorMap tensor_map_output_t, + bool pow_2_scaling) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + // This is ONLY true if the input is a full tile + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_idx = + tile_id_y * BLOCK_TILE_DIM * row_length + tile_id_x * BLOCK_TILE_DIM; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + +// Step 1: Load a block tile of input data into thread tiles on registers +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from(input + thread_tile_start_idx + i * row_length); + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = + compute_scale_from_types(block_tile_amax, epsilon, pow_2_scaling); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + OVecCast tmp_output_c; + + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length); + } + + // Step 4: store transpose into shared memory + if constexpr (kReturnTranspose) { +#ifdef TMA_HW_SUPPORTED + __shared__ alignas(128) + OVecTrans block_tile_trans_shared[SHARED_BLOCK_TILE_DIM_Y][SHARED_BLOCK_TILE_DIM_X_BANKS]; + OType(*block_tile_trans_shared_otype_ptr)[BLOCK_TILE_DIM] = + reinterpret_cast(block_tile_trans_shared); + +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + auto warp_id_in_block_x_ = warp_id_in_block_y; + auto warp_id_in_block_y_ = warp_id_in_block_x; + int row_idx = warp_id_in_block_y_ * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X + i; + int col_idx = + warp_id_in_block_x_ * (NUM_BANKS_Y_IN_WARP / NUM_BANKS_PER_SHARED_ELEM) + tid_in_warp_y; + block_tile_trans_shared[row_idx][col_idx] = thrd_tile_out_trans[i]; + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Step 5: store transpose output + // Initiate TMA transfer to copy shared memory to global memory + if (threadIdx.x == 0) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), tile_id_y * BLOCK_TILE_DIM, + tile_id_x * BLOCK_TILE_DIM, + reinterpret_cast(block_tile_trans_shared_otype_ptr)); + // Wait for TMA transfer to have finished reading shared memory. + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + // Wait for the group to have completed reading from shared memory. + ptx::cp_async_bulk_wait_group_read<0>(); + } +#else + // Step 4 Alternative (when TMA is not available, skip writing to shared memory) + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + thrd_tile_out_trans[i].store_to(output_t + thread_tile_t_start_idx + i * num_rows); + } +#endif + } +} + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( + const IType* const input, OType* const output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, + bool pow_2_scaling) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_row_idx = tile_id_y * BLOCK_TILE_DIM; + const size_t block_tile_start_col_idx = tile_id_x * BLOCK_TILE_DIM; + const size_t block_tile_start_idx = + block_tile_start_row_idx * row_length + block_tile_start_col_idx; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + // handle non-full tile + // check for three cases: full thread tile, nonfull thread tile, empty thread tile + // for empty thread tile, directly write zero to the transposed shared mem buffer + // for nonfull thread tile, fill zero to thread tile and act as if it's full + const size_t thread_tile_start_row_idx = + tile_id_y * BLOCK_TILE_DIM + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP + + tid_in_warp_y * THREAD_TILE_DIM_Y; + const size_t thread_tile_start_col_idx = + tile_id_x * BLOCK_TILE_DIM + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X; + + const size_t thread_tile_end_row_idx = thread_tile_start_row_idx + THREAD_TILE_DIM_Y - 1; + const size_t thread_tile_end_col_idx = thread_tile_start_col_idx + THREAD_TILE_DIM_X - 1; + + bool full_thrd_tile = + (thread_tile_end_row_idx < num_rows) && (thread_tile_end_col_idx < row_length); + bool empty_thrd_tile = + (thread_tile_start_row_idx >= num_rows) || (thread_tile_start_col_idx >= row_length); + bool nonfull_thrd_tile = (!full_thrd_tile) && (!empty_thrd_tile); + + const size_t thread_tile_ncols = + MIN(THREAD_TILE_DIM_X, + (MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1)); + const size_t thread_tile_nrows = + MIN(THREAD_TILE_DIM_Y, + (MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1)); + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + + if (!empty_thrd_tile) { + // Step 1: Load a block tile of input data into thread tiles on registers + // Edge case: nonfull thread tile case, will use the partial load function here + if (nonfull_thrd_tile) { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + thrd_tile_input[i].clear(); + } else { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + } + } else { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + THREAD_TILE_DIM_X); + } + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = + compute_scale_from_types(block_tile_amax, epsilon, pow_2_scaling); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + // Edge case: in the non-full tile case, there are three subcases + // for full thread tile, it's the same thing here + // for nonfull thread tile, pay attention when saving tmp_output_c to global + // memory, cannot vec store_to, but need to elt store to for empty tile, + // it should not enter this step, skip to Step 4 + + // set thrd_tile_out_trans to all zero + if constexpr (kReturnTranspose) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + thrd_tile_out_trans[j].clear(); + } + } + + if (!empty_thrd_tile) { + OVecCast tmp_output_c; + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + continue; + } +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + + if constexpr (kReturnTranspose) { + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < thread_tile_ncols; i++) { + thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0, + thread_tile_nrows); + } + } + } +} + +template +CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { + CUtensorMapDataType dataType; + if constexpr (std::is_same_v || + std::is_same_v) { + dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else { + NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + } + + CUtensorMap tensor_map_output_trans{}; + create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x, + /*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM, + /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType)); + return tensor_map_output_trans; +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow_2_scale, + cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_square_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + } + + NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions."); + + size_t scale_k = scale_inv.shape[1]; + + const size_t scale_stride_x = 1; + const size_t scale_stride_y = scale_k; + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type."); + + NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions."); + + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); + const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, OutputType, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + const bool full_tile = + row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; + + if (full_tile) { + CUtensorMap tensor_map_output_trans; + if (return_transpose) { + tensor_map_output_trans = + get_tensor_map(output_t, num_rows, row_length); + } + block_scaled_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + tensor_map_output_trans, pow_2_scale); + } else { + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + pow_2_scale); + } // full-tile + ) // return_transpose + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu new file mode 100644 index 0000000000..732d97999c --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -0,0 +1,479 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/recipe/recipe_common.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace { + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr size_t kThreadsPerWarp = 32; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; // Fixed to 128 beacause we are using 1x128 and 128x1 quantization +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; +constexpr int kNumThreadsStore = kTileDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, + const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon, + bool return_transpose, bool pow_2_scaling) { + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = + static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = c_g < row_length ? min(static_cast(kNVecIn), row_length - c_g) + : 0; // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + // Step 2: Cast and store to output_c + { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = + static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = c_g < row_length ? min(static_cast(kNVecOut), row_length - c_g) + : 0; // For not aligned case + OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + CType scale; + // Step 2.4: Compute scale + scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = static_cast(blockIdx.y) * kTileDim + r_s; + size_t col_idx = static_cast(blockIdx.x); + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + output_vec.data.elt[i * kNVecSMem + j] = + static_cast(static_cast(smem_vec[i].data.elt[j]) * scale); + } + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if (return_transpose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = + static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = static_cast(blockIdx.y) * kTileDim + r_s; // Column in global memory + const size_t stride_g = + static_cast(c_stride) * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = c_g < num_rows ? min(static_cast(kNVecOut), num_rows - c_g) + : 0; // For not aligned case + OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + // Step 3.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + // Step 3.4: Compute scale + CType scale; + scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = static_cast(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = static_cast(blockIdx.y); + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + output_vec.data.elt[i] = + static_cast(static_cast(smem_vec[i].data.elt[smem_idx]) * scale); + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow2_scale, + cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_vector_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + // Options for scale layout of cuBLAS GEMM kernel. + + NVTE_CHECK(input.shape.size() == output.shape.size(), + "Input and output must have the same shape."); + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = scale_k; + scale_stride_y = 1; + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); + + NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2."); + scale_t_stride_x = scale_inv_t.shape[1]; + scale_t_stride_y = 1; + } + + const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); + const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, kAligned, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + // shared memory must be requested up + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + &block_scaled_1d_cast_transpose_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); + } block_scaled_1d_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose, + pow2_scale);) // kAligned + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index ba2890ada3..412a6f6ef0 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1262,6 +1262,30 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe workspace_tensor, stream); break; } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); + constexpr bool force_pow_2_scales = true; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/0.0, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); + constexpr bool force_pow_2_scales = true; + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/0.0, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e529289640..c885c69333 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -349,6 +349,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } } else { + // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index a22b930ecd..55bc247f70 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( : "memory"); } +__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { + uint32_t waitComplete; + asm volatile( + "{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(mbar_ptr), "r"(parity) + : "memory"); + return static_cast(waitComplete); +} + +__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { + } +} + +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, @@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( : "memory"); } -__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { - uint32_t waitComplete; - asm volatile( - "{\n\t .reg .pred P_OUT; \n\t" - "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P_OUT; \n" - "}" - : "=r"(waitComplete) - : "r"(mbar_ptr), "r"(parity) - : "memory"); - return static_cast(waitComplete); -} - -__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { - uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); - while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { - } -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group -__device__ __forceinline__ void cp_async_bulk_commit_group() { - asm volatile("cp.async.bulk.commit_group;"); -} - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group __device__ __forceinline__ void cp_async_bulk_wait_group() { asm volatile("cp.async.bulk.wait_group 0;"); @@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { asm volatile("cp.async.bulk.wait_group.read 4;"); } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + // Proxy fence (bi-directional): __device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } + __device__ __forceinline__ void fence_proxy_async_shared_cta() { asm volatile("fence.proxy.async.shared::cta;"); } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 7aecc34643..805c034334 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -183,8 +183,8 @@ class ScalingMode(Enum): NVTE_DELAYED_TENSOR_SCALING = 0 NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 2 - NVTE_NO_SCALING = 3 + NVTE_INVALID_SCALING = 4 + NVTE_NO_SCALING = 5 def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 3d807960ca..d1470e22e3 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -24,6 +24,12 @@ torch.bfloat16: tex.DType.kBFloat16, } +""" +This is a map: int -> torch.dtype +Used for resolving cuda extension types to torch. +Has one to one mapping with enum in +transformer_engine.h +""" TE_DType_To_Torch = { tex.DType.kByte: torch.uint8, tex.DType.kFloat8E4M3: torch.float8_e4m3fn, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2cf47e7399..338f1fcbb1 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -163,6 +163,38 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::optional rowwise_data = std::nullopt) const override; }; +class Float8BlockQuantizer : public Quantizer { + public: + // Which float8 type is used for q data. + DType dtype; + + private: + // Options about how to quantize the tensor + // Quantization scales are rounded down to powers of 2. + bool force_pow_2_scales = false; + // Amax within quantization tile has a floor of epsilon. + float amax_epsilon = 0.0; + int block_scaling_dim = 2; + + public: + // Initializes from a python handle to a Float8BlockQuantizer + explicit Float8BlockQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + + // Gets rowwise and columnwise_data from tensor and sets them on wrapper + void set_quantization_params(TensorWrapper* tensor) const override; + + // Create a python Float8BlockQuantized tensor and C++ wrapper + // for the tensor. Should set quantized data, scales for rowwise + // and optionally columnwise usage. + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + class MXFP8Quantizer : public Quantizer { public: DType dtype; diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu index d262767958..0770e63015 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu @@ -12,6 +12,8 @@ // #include #include + +#include // Stringstream is a big hammer, but I want to rely on operator<< for dtype. #include @@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor { n -= chunk_idx * chunk_size; for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { - float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8, - force_pow_2_scales, epsilon); + float scale_val = transformer_engine::compute_scale_from_amax( + amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits::max()); scale[i_start] = scale_val; transformer_engine::reciprocal(scale_inv + i_start, scale_val); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c966f2ba97..60a97dad3c 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorBasePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; +PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -61,9 +64,31 @@ void init_mxfp8_extension() { "Internal error: could not initialize pyTorch MXFP8 extension."); } +void init_float8blockwise_extension() { + if (Float8BlockwiseQTensorBasePythonClass) return; + auto fp8_module = + py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); + auto fp8_base_module = py::module_::import( + "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base"); + Float8BlockwiseQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); + Float8BlockwiseQTensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase")); + Float8BlockwiseQTensorPythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); + + NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); + init_float8blockwise_extension(); } } // namespace transformer_engine::pytorch @@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 5121bc7f88..19d8a75a64 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -250,6 +250,142 @@ std::pair Float8CurrentScalingQuantizer::create_tenso tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); } this->set_quantization_params(&tensor); + + return {std::move(tensor), std::move(ret)}; +} + +Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast(), + "Pending additional parameters to the nvte_quantize API, " + "float8 block quantization requires pow2 scales"); + NVTE_CHECK(quantizer.attr("amax_epsilon").cast() == 0.0, + "Pending additional parameters to the nvte_quantize API, " + "float8 block quantization requires amax_epsilon==0"); + NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, + "Unsupported block scaling dim."); +} + +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // Change the rowwise and columnwise_data to the configured dtype. + // May be a switch between E5M2 and E4M3. + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8BlockQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(this->get_scaling_mode()); + at::TensorOptions opts; + at::TensorOptions scale_opts; + at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + + size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back(); + size_t m_dim = numel / k_dim; + constexpr size_t kBlockLen = 128; + + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data_rowwise = std::move(*rowwise_data); + } else { + data_rowwise = at::empty(torch_shape, opts); + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup(m_dim, 4); + } else { + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor rowwise." + "Expected 1 or 2. Got ", + block_scaling_dim); + } + scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts); + tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + + if (columnwise_usage) { + std::vector torch_columnwise_shape; + std::vector columnwise_shape; + NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ", + columnwise_shape, " torch shape: ", torch_columnwise_shape); + if (torch_shape.size() > 0) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup(k_dim, 4); + } else { + NVTE_CHECK(false, + "Unsupported block_scaling_dim in create_tensor columnwise." + "Expected 1 or 2. Got ", + block_scaling_dim); + } + data_colwise = at::empty(torch_columnwise_shape, opts); + scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts); + + tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); + tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + ret = Float8BlockwiseQTensorClass( + "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, + "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, + "is_2D_scaled"_a = (block_scaling_dim == 2)); + } else { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorPythonClass)); + ret = Float8BlockwiseQTensorClass( + "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, + "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, + "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); + } + return {std::move(tensor), std::move(ret)}; } @@ -302,8 +438,9 @@ std::pair MXFP8Quantizer::create_tensor( auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{sinv0, sinv1}); + tensor.set_rowwise_scale_inv( + rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{static_cast(sinv0), static_cast(sinv1)}); } if (columnwise_usage) { @@ -313,8 +450,9 @@ std::pair MXFP8Quantizer::create_tensor( columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); - tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{sinv0, sinv1}); + tensor.set_columnwise_scale_inv( + columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{static_cast(sinv0), static_cast(sinv1)}); } this->set_quantization_params(&tensor); diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index d5654fb43a..cb2121a457 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -84,6 +84,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &rowwise_shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + quantizer->set_quantization_params(&ret); + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index b0f55d7598..c7b3167e78 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -25,6 +25,9 @@ extern PyTypeObject *Float8CurrentScalingQuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; extern PyTypeObject *MXFP8TensorBasePythonClass; extern PyTypeObject *MXFP8QuantizerClass; +extern PyTypeObject *Float8BlockwiseQTensorPythonClass; +extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; +extern PyTypeObject *Float8BlockwiseQuantizerClass; void init_extension(); @@ -50,6 +53,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; } +inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; +} + +inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || + Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; +} + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -61,6 +73,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati std::unique_ptr CreateMXFP8Params(const py::handle params); +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, + Quantizer *quantization_params); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } @@ -71,7 +86,9 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, - CreateQuantizer)}; + CreateQuantizer), + std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, + NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; } // namespace detail diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py new file mode 100644 index 0000000000..9135237854 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -0,0 +1,240 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8BlockwiseQTensor""" + +from __future__ import annotations +import math +from typing import Optional, Dict, Any, Tuple +import torch + +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType_To_Torch + +from ..quantized_tensor import Quantizer + + +class Float8BlockwiseQTensorBase: + """Mixin class that holds data attributes of Float8BlockwiseQTensor. + + Float8BlockwiseQTensor inherits from the PyTorch tensor class and this + mixin class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Quantizer + _fp8_dtype: TE_DType + _rowwise_scale_inv: Optional[torch.Tensor] + _columnwise_scale_inv: Optional[torch.Tensor] + _is_2D_scaled: bool + + def __new__( + cls, + *args, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: TE_DType, + quantizer: Quantizer, + is_2D_scaled: bool, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + instance._is_2D_scaled = is_2D_scaled + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + "is_2D_scaled": self._is_2D_scaled, + } + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: + """Prepare the tensor base for saving for backward""" + tensors = [self._rowwise_data, self._columnwise_data] + self._rowwise_data = None + self._columnwise_data = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: + """Takes dequantized columnwise data and permutes to a rowwise shape""" + if columnwise_dq.dim() < 2: + return columnwise_dq + permute_dims = list(range(1, columnwise_dq.dim())) + permute_dims.append(0) + return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() + + def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + block_len = 128 + + q_M, q_K = 1, 1 + if self._rowwise_data is not None: + q = self._rowwise_data + scale_inv = self._rowwise_scale_inv + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + else: + assert self._columnwise_data is not None, "No data to dequantize" + q = self._columnwise_data + scale_inv = self._columnwise_scale_inv + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + + orig_shape = q.shape + q = q.reshape(q_M, q_K) + k_tiles, scale_m = scale_inv.shape + if q_K % block_len != 0: + k_pad_amount = (block_len - (q_K % block_len)) % block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, 0), mode="constant", value=0 + ).contiguous() + _, padded_K = q.shape + q_tiled = q.reshape(q_M, k_tiles, block_len) + if scale_m > q_M: + # scale_m is 4 element aligned. + scale_inv = scale_inv[:, :q_M].contiguous() + dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1) + torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] + result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale + if padded_K != q_K: + result = result.reshape(q_M, padded_K)[:, :q_K] + result = result.to(dtype) + if len(orig_shape) == 0: + result = result.reshape([]) + else: + result = result.reshape(*orig_shape).contiguous() + + if transpose_output: + return self._transpose_dq_columnwise_output(result) + return result + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + """ + block_len = 128 + if not self._is_2D_scaled: + return self._dequantize_vectorwise(dtype=dtype) + + def format_scale_as_logical_shape(q_K, scales, block_len): + # The GEMM for 2D blocks required padding in the scales. + derived_scale_k_shape = math.ceil(q_K / block_len) + _, scale_K = scales.shape + if derived_scale_k_shape == scale_K: + return scales + return scales[:, :derived_scale_k_shape].contiguous() + + q_M, q_K = 1, 1 + if self._rowwise_data is not None: + q = self._rowwise_data + scale_inv = self._rowwise_scale_inv + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + else: + assert self._columnwise_data is not None, "No data to dequantize" + q = self._columnwise_data + scale_inv = self._columnwise_scale_inv + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + + orig_shape = q.shape + q = q.reshape(q_M, q_K) + formatted_scales = format_scale_as_logical_shape(q_K, scale_inv, block_len) + assert len(formatted_scales.shape) == 2 + m_tiles, k_tiles = formatted_scales.shape + unpadded_m, unpadded_k = q_M, q_K + m_block_len = block_len + k_block_len = block_len + if q_M % m_block_len != 0 or q_K % k_block_len != 0: + m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len + k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + padded_M, padded_K = q.shape + q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len) + + torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] + + result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view( + m_tiles, 1, k_tiles, 1 + ) + result = result.view(padded_M, padded_K).to(dtype) + if padded_M != unpadded_m or padded_K != unpadded_k: + result = result[:unpadded_m, :unpadded_k] + if len(orig_shape) == 0: + result = result.reshape([]) + else: + result = result.reshape(*orig_shape).contiguous() + if transpose_output: + return self._transpose_dq_columnwise_output(result) + return result + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + if self._rowwise_data is not None: + return self._rowwise_data.size(*args, **kwargs) + dims = list(self._columnwise_data.size(*args, **kwargs)) + reordered = [] + for i in range(1, len(dims)): + reordered.append(dims[i]) + reordered.append(dims[0]) + return torch.Size(reordered) + + def __repr__(self): + if self._rowwise_data is not None: + data = self.dequantize() + descriptor = "rowwise" + else: + data = self.dequantize() + descriptor = "columnwise" + return ( + "Float8BlockwiseQTensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"{descriptor}_scaled_data={data}" + ) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py new file mode 100644 index 0000000000..138d1fd29e --- /dev/null +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -0,0 +1,608 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data quantized with NxN tiles""" +from __future__ import annotations +from typing import Optional, Tuple, Iterable + +import math +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..utils import devices_match, round_up_to_nearest_multiple + +aten = torch.ops.aten + + +class Float8BlockQuantizer(Quantizer): + """Builder class for tensors quantized with current scaling using + NxN quantization tilings to choose scale. + + This class is typically used to convert a high-precision tensor + (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + dtype: TE_DType + block_len: int + amax_epsilon: float + force_pow_2_scales: bool + block_scaling_dim: int + + def __init__( + self, + fp8_dtype: TE_DType, + *, + rowwise: bool, + columnwise: bool, + amax_epsilon: float = 0.0, + force_pow_2_scales: bool = True, + block_scaling_dim: int = 2, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + assert rowwise + self.dtype = fp8_dtype + self.block_len = 128 + self.force_pow_2_scales = force_pow_2_scales + self.amax_epsilon = amax_epsilon + self.block_scaling_dim = block_scaling_dim + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Update the quantized tensor with data from the source tensor. + + This method quantizes the input tensor and stores the result in the destination tensor. + + Parameters + ---------- + src : torch.Tensor + Source tensor containing the data to be quantized + dst : QuantizedTensor + Destination tensor where the quantized data will be stored + noop_flag : Optional[torch.Tensor] + Optional flag tensor indicating whether to skip the quantization operation + + Returns + ------- + QuantizedTensor + The destination tensor containing the quantized data + + Raises + ------ + AssertionError + If the destination tensor is not a Float8BlockwiseQTensor + """ + assert isinstance( + dst, Float8BlockwiseQTensor + ), f"Cannot store quantized blockwise tensor in {type(dst)} type." + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + dst._fp8_dtype = self.dtype + return dst + + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For 2D tensors: + - If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4)) + - If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4)) + For 1D tensors: + - If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4)) + - If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4)) + """ + M, K = 1, 1 + for i in range(len(shape) - 1): + M *= shape[i] + if len(shape) > 0: + K = shape[-1] + if self.block_scaling_dim == 2: + if columnwise: + outer = math.ceil(K / self.block_len) + inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) + return (outer, inner) + outer = math.ceil(M / self.block_len) + inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) + return (outer, inner) + assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" + if columnwise: + outer = math.ceil(M / self.block_len) + inner = round_up_to_nearest_multiple(K, 4) + return (outer, inner) + outer = math.ceil(K / self.block_len) + inner = round_up_to_nearest_multiple(M, 4) + return (outer, inner) + + def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of a tensor after columnwise permutation. + + This method rearranges the dimensions of a tensor to be columnwise, + moving the last dimension to the front and keeping the order of other dimensions. + + Parameters + ---------- + shape : Iterable[int] + Original shape of the tensor + + Returns + ------- + Tuple[int, ...] + New shape with dimensions rearranged for columnwise layout. + For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1). + Returns empty tuple for empty input shape. + """ + if len(shape) == 0: + return tuple() + colwise_shape = [shape[-1]] + for i in range(len(shape) - 1): + colwise_shape.append(shape[i]) + return tuple(colwise_shape) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> Float8BlockwiseQTensor: + """Construct quantized tensor with uninitialized data""" + if device is None: + device = torch.device("cuda") + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty( + self.get_columnwise_shape(shape), dtype=torch.uint8, device=device + ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, + dtype=torch.float32, + device=device, + ) + + # Construct FP8 tensor + return Float8BlockwiseQTensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + is_2D_scaled=self.block_scaling_dim == 2, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # NOTE: This interface is specific to requirements like delayed scaling + # where state from an estimator influences distribution parameters. + pass + + +class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): + """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + rowwise_data: torch.Tensor + FP8 data in a uint8 tensor matching shape of dequantized tensor. + rowwise_scale_inv: torch.Tensor + FP32 dequantization scales in GEMM format for dequantizing rowwise_data. + columnwise_data: Optional[torch.Tensor] + FP8 data in a uint8 tensor matching shape of dequantized tensor transpose. + columnwise_scale_inv: Optional[torch.Tensor] + FP32 dequantization scales in GEMM format for dequantizing columnwise_data. + + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and + holds configuration about quantization and dequantization modes. + """ + + def __repr__(self, *, tensor_contents=None): + return ( + f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," + f" is_2D_scaled={self._is_2D_scaled}," + f" data={self.dequantize(dtype=self.dtype)})" + ) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + assert self._quantizer is not None + return self._quantizer + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> Float8BlockwiseQTensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + + By default the resulting tensor's dtype is the + Float8BlockwiseQTensor's pre-quantized dtype. + """ + if dtype is not None: + dequant_dtype = dtype + else: + dequant_dtype = self.dtype + return super().dequantize(dtype=dequant_dtype) + + def detach(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return Float8BlockwiseQTensor.make_like(self) + + def update_usage( + self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None + ): + """ + update_usage can be used to clear out one of two possible copies of the data. + """ + + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + assert ( + columnwise_usage or rowwise_usage + ), "Must retain some data either columnwise or rowwise" + + if columnwise_usage and rowwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage." + return + + if rowwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise usage." + self._columnwise_data = None + self._columnwise_scale_inv = None + return + if columnwise_usage: + assert ( + self._columnwise_data is not None and self._columnwise_scale_inv is not None + ), "Cannot update to columnwise usage." + self._rowwise_data = None + self._rowwise_scale_inv = None + return + + return + + def clone(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + rowwise_data = None + if self._rowwise_data is not None: + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + if data is None: + # Columnwise data only. + super().__torch_dispatch__(func, types, args, kwargs) + orig_size = data.size() + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + if orig_size != out_data.size(): + raise NotImplementedError( + "Changing shape with view not implemented " + " (scales and columnwise data untouched)." + ) + return Float8BlockwiseQTensor.make_like(tensor) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8BlockwiseQTensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if ( + self._rowwise_data is not None + and self._rowwise_data.is_contiguous(memory_format=memory_format) + and ( + (self._columnwise_data is None) + or (self._columnwise_data.is_contiguous(memory_format=memory_format)) + ) + ): + return self + raise ValueError("Float8BlockwiseQTensor does not support different memory formats!") + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + is_2D_scaled: bool, + ) -> Float8BlockwiseQTensor: + """Build Float8BlockwiseQTensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return Float8BlockwiseQTensor( + shape=shape, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + quantizer=quantizer, + is_2D_scaled=is_2D_scaled, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8BlockwiseQTensor._make_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + self._quantizer, + self._is_2D_scaled, + ), + ) + + def _get_data(self) -> Float8BlockwiseQTensor: + """Get tensor data property""" + return self + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a Float8BlockwiseQTensor. Otherwise + casts to FP8. + + """ + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + + def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): + dst._rowwise_data = src._rowwise_data + dst._columnwise_data = src._columnwise_data + dst._quantizer = src._quantizer + dst._fp8_dtype = src._fp8_dtype + dst._rowwise_scale_inv = src._rowwise_scale_inv + dst._columnwise_scale_inv = src._columnwise_scale_inv + dst.dtype = src.dtype + + # Check that tensor dimensions match + if ( + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.layout != tensor.layout + ): + raise ValueError("Invalid tensor for updating Float8BlockwiseQTensor data") + + # Just copy FP8 data if other tensor is Float8BlockwiseQTensor + if ( + isinstance(tensor, Float8BlockwiseQTensor) + and self.storage_offset() == tensor.storage_offset() + and devices_match(self.device, new_device) + ): + _set_from_tensor(self, tensor) + return + + if isinstance(tensor, Float8BlockwiseQTensor): + assert tensor._quantizer is not None, "Can't quantize without a quantizer" + quantizer = tensor._quantizer + else: + assert self._quantizer is not None, "Can't quantize without a quantizer" + quantizer = self._quantizer + + # Quantize to FP8 + quantizer.update_quantized(tensor, self) + + # Cast to FP8 when setting Float8BlockwiseQTensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8BlockwiseQTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8BlockwiseQTensor, + shape: Optional[list[int]] = None, + ) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + if ctx is not None: + ctx.shape = tensor.shape + if shape is None: + return tensor + + if list(shape) != list(tensor.shape): + raise NotImplementedError("View not implemented.") + return tensor + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, Float8BlockwiseQTensor): + raise NotImplementedError("View bwd not implemented") + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8BlockwiseQTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8BlockwiseQTensor, + shape: Optional[list[int]] = None, + ) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + if ctx is not None: + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(tensor.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if list(shape) != list(tensor.shape): + raise NotImplementedError("Reshape not implemented yet.") + return tensor + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, Float8BlockwiseQTensor): + raise NotImplementedError("Reshape bwd not implemented yet.") + return grad.view(ctx.shape), None From be1f647c4f07d95e66b2fa8e611578766c37f927 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 4 Apr 2025 05:58:11 -0700 Subject: [PATCH 03/29] [JAX-Q] Distributed MXFP8 flax layer tests (#1643) MXFP8 flax layer tests Signed-off-by: Jeremy Berchtold --- tests/jax/test_distributed_layernorm_mlp.py | 49 +++++++++++++-------- transformer_engine/jax/flax/module.py | 7 ++- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 0586d2b6c7..efc24fe6ea 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -267,9 +267,18 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, + scale_axes=(W_NO_SHARD_AXES,), + ln_bias_axes=(W_NO_SHARD_AXES,), + kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), + kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), use_bias=use_bias, + bias_axes_1=(W_JOINED_AXES, W_TP_AXES), + bias_axes_2=(W_NO_SHARD_AXES,), + layernorm_input_axes=LAYERNORM_INPUT_AXES, + dot_1_input_axes=DOT_1_INPUT_AXES, + dot_2_input_axes=DOT_2_INPUT_AXES, ) - params_single = ln_mlp_single.init(init_rngs, x) + params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) mlp_out_single, ln_out_single = ln_mlp_single.apply( params_single, x, deterministic=True ) @@ -298,7 +307,7 @@ def _test_layernorm_mlp( dot_2_input_axes=DOT_2_INPUT_AXES, name="mlp", ) - params_sharded = ln_mlp_sharded.init(init_rngs, x) + params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( params_sharded, x, deterministic=True ) @@ -318,20 +327,22 @@ def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False ) - # TODO: debug - # @pytest.mark.skipif(not is_fp8_supported, reason=reason) - # @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) - # @pytest_parametrize_wrapper( - # "activation_type", [("gelu",), ("gelu", "linear")] - # ) - # @pytest_parametrize_wrapper("use_bias", [True, False]) - # @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - # @pytest_parametrize_wrapper("dtype", DTYPES) - # @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - # def test_layernorm_fp8_mlp_layer( - # self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe - # ): - # self._test_layernorm_mlp( - # mesh_config, activation_type, use_bias, input_shape, dtype, - # use_fp8=True, fp8_recipe=fp8_recipe - # ) + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + def test_layernorm_fp8_mlp_layer( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + ): + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=True, + fp8_recipe=fp8_recipe, + ) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 56672fb6bf..a0d1e33e38 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1088,6 +1088,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape) if not QuantizeConfig.is_fp8_enabled(): kernel_1 = kernel_1.astype(input_dtype) + if self.kernel_axes_1 is not None: + kernel_1 = with_sharding_constraint_by_logical_axes( + kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:] + ) hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple @@ -1105,7 +1109,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape) if not QuantizeConfig.is_fp8_enabled(): kernel_2 = kernel_2.astype(input_dtype) - + if self.kernel_axes_2 is not None: + kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2) contract_ind = tuple(range(0, len(axis))) if self.use_bias: From fbcbcb0924a16c8301e3e02ca0a5a1a266606179 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 28 Feb 2025 15:30:07 -0800 Subject: [PATCH 04/29] Add GEMM logic for blockwise quantized tensors. GEMM test cases included in pytorch integration. Signed-off-by: Keith Wyss --- .../blockwise_fp8_gemm_reference.py | 238 +++++ .../blockwise_quantizer_reference.py | 1 + .../test_float8_blockwise_gemm_exact.py | 832 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 263 ++++-- .../common/normalization/layernorm/ln_api.cpp | 4 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 4 +- .../csrc/extensions/type_converters.cpp | 7 + 7 files changed, 1262 insertions(+), 87 deletions(-) create mode 100644 tests/pytorch/references/blockwise_fp8_gemm_reference.py create mode 100644 tests/pytorch/test_float8_blockwise_gemm_exact.py diff --git a/tests/pytorch/references/blockwise_fp8_gemm_reference.py b/tests/pytorch/references/blockwise_fp8_gemm_reference.py new file mode 100644 index 0000000000..3487dfb810 --- /dev/null +++ b/tests/pytorch/references/blockwise_fp8_gemm_reference.py @@ -0,0 +1,238 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128): + pid = tl.program_id(0) + idx = pid * BLOCK + tl.arange(0, BLOCK) + mask = idx < M * N + + row = idx // N + col = idx % N + + y_offset = row * y_str0 + col * y_str1 + x_offset = row * N + col + s_offset = row * N + col + + y = tl.load(y_ptr + y_offset, mask=mask) + x = tl.load(x_ptr + x_offset, mask=mask) + s = tl.load(s_ptr + s_offset, mask=mask) + + tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask) + + +def fused_fma(y, x, s, BLOCK=128): + """ + Fused multiply-add operation (y = y + x * s). + + PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation). + This function also supports cases where 'y' is non-contiguous in memory. + """ + + assert ( + y.shape == x.shape == s.shape and y.dim() == 2 + ), "All tensors must be 2D with the same shape" + assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous" + + M, N = y.shape + grid = ((M * N + BLOCK - 1) // BLOCK,) + + fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK) + + return y + + +class CuBLASRefBlockwiseGemm: + """ + A cuBLAS compatible reference implementation of subchannel GEMM. + """ + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + demunged_sx: torch.Tensor, + demunged_sw: torch.Tensor, + quant_tile_shape_x: Tuple[int, int], + quant_tile_shape_w: Tuple[int, int], + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + use_split_accumulator: bool = False, + ) -> torch.Tensor: + # demunge scale shapes for cuBLAS + is_a_1d_scaled = quant_tile_shape_x[0] == 1 + is_b_1d_scaled = quant_tile_shape_w[0] == 1 + M, K = qx.shape + N, K = qw.shape + + # mm_tile_shape = (tile_m, tile_n, tile_k) + mm_tile_shape = ( + quant_tile_shape_x[0], + quant_tile_shape_w[0], + quant_tile_shape_w[1], + ) + if bias is not None and bias.numel(): + # To match cuBLAS more closely when bias is applied, + # the reference accumulates into float32, and cast to + # bfloat16 is deferred until after the GEMM. + out_dtype_for_ref = torch.float32 + else: + out_dtype_for_ref = out_dtype + y = self.qgemm_blockwise_2d( + qx, + qw, + out_dtype_for_ref, + demunged_sx, + demunged_sw, + mm_tile_shape, + use_split_accumulator, + is_a_1d_scaled, + is_b_1d_scaled, + ) + if bias is not None and bias.numel(): + y += bias + y = y.to(dtype=out_dtype) + # cublas accumulation first convert to output dtype, then accumulate. + if accumulate: + assert out is not None + y = y + out + else: + assert out is None, "Output tensor should be None when accumulate is False." + + return y + + @classmethod + def qgemm_blockwise_2d( + cls, + qx: torch.Tensor, + qw: torch.Tensor, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + mm_tile_shape: Tuple[int, int, int], + use_split_accumulator: bool, + is_a_1d_scaled: bool, + is_b_1d_scaled: bool, + ) -> torch.Tensor: + """ + Difference between cuBLAS and CUTLASS GEMM implementations: + - cuBLAS accumulation equation: use different equation for each scaling mode. + - For accumulation C in epiloge, it first convert C to output dtype, then accumulate. + """ + + M, K = qx.shape + N, K_w = qw.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + tile_len = 128 + # Calculate grid sizes without padding + grid_m = (M + tile_len - 1) // tile_len + grid_n = (N + tile_len - 1) // tile_len + grid_k = (K + tile_len - 1) // tile_len + + block_m, block_n, block_k = mm_tile_shape + scale_m_per_tile = tile_len // block_m + scale_n_per_tile = tile_len // block_n + assert block_k == tile_len, "block_k must be equal to tile_len" + + # Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM: + # 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers. + # 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # Validate shapes of sx and sw + scale_m_per_tensor = (M + block_m - 1) // block_m + scale_n_per_tensor = (N + block_n - 1) // block_n + assert sx.shape == ( + scale_m_per_tensor, + grid_k, + ), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}" + assert sw.shape == ( + scale_n_per_tensor, + grid_k, + ), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}" + + for i in range(grid_m): + m_start = i * tile_len + m_end = min(m_start + tile_len, M) + m_size = m_end - m_start + + for j in range(grid_n): + n_start = j * tile_len + n_end = min(n_start + tile_len, N) + n_size = n_end - n_start + + y_block = y[m_start:m_end, n_start:n_end] + + for k in range(grid_k): + k_start = k * tile_len + k_end = min(k_start + tile_len, K) + k_size = k_end - k_start + + qx_block = ( + qx[m_start:m_end, k_start:k_end].clone().contiguous() + ) # Shape: [m_size, k_size] + qw_block = ( + qw[n_start:n_end, k_start:k_end].clone().contiguous() + ) # Shape: [n_size, k_size] + + # Extract scaling factors for the current blocks + sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze( + -1 + ) + sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0) + + # Perform qgemm with scaling factors fused in the GEMM + # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM + one = torch.tensor(1.0, dtype=torch.float32, device=qx.device) + y_partial = torch._scaled_mm( + qx_block, + qw_block.t(), + scale_a=one, + scale_b=one, + out_dtype=torch.float32, + use_fast_accum=not use_split_accumulator, + ) + + # Accumulate the partial result + if is_a_1d_scaled and is_b_1d_scaled: + # 1Dx1D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + # Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM + # y_block.add_(y_partial, alpha=scale.item()) + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + elif not is_a_1d_scaled and is_b_1d_scaled: + # 2Dx1D + # CuBLAS accumulation equation: y += (y * scale_b) * scale_a + y_partial = y_partial * sw_block + fused_fma( + y_block, + y_partial, + sx_block.expand_as(y_partial).contiguous(), + ) + elif is_a_1d_scaled and not is_b_1d_scaled: + # 1Dx2D + # CuBLAS accumulation equation: y += (y * scale_a) * scale_b + y_partial = y_partial * sx_block + fused_fma( + y_block, + y_partial, + sw_block.expand_as(y_partial).contiguous(), + ) + else: + scale = sx_block * sw_block + fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous()) + + y = y.to(out_dtype) + return y diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index b98966f514..f5c9dc0e96 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -49,6 +49,7 @@ def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1) return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + @classmethod def demunge_scale_shape_from_backend( cls, qtensor_shape: Tuple[int, int], diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py new file mode 100644 index 0000000000..a118c6f81c --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -0,0 +1,832 @@ +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from tests.pytorch.references.blockwise_quantizer_reference import CuBLASScaleMunger +from tests.pytorch.references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm + + +def fp8_blockwise_gemm_supported() -> bool: + return float(torch.version.cuda) >= 12.8 + + +def cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + atol: float = 0.0, + rtol: float = 0.0 +): + if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2: + pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2") + if not (is_x_1d_scaled or is_w_1d_scaled): + pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile") + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + if noise_type == "uniform": + x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude + w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude + elif noise_type == "normal": + x = torch.randn(x_shape, dtype=torch.float32, device=device) * x_magnitude + w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude + else: + assert False + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude + else: + out = None + + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Reference GEMM + ref_gemm = CuBLASRefBlockwiseGemm() + scale_decoder = CuBLASScaleMunger() + qx_data = ( + qx._columnwise_data.view(dtype=x_dtype) + if x_columnwise + else qx._rowwise_data.view(dtype=x_dtype) + ) + qw_data = ( + qw._columnwise_data.view(dtype=w_dtype) + if w_columnwise + else qw._rowwise_data.view(dtype=w_dtype) + ) + ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv + ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv + y_ref = ref_gemm.qgemm( + qx=qx_data, + qw=qw_data, + out_dtype=out_dtype, + demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape + ), + demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend( + qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape + ), + quant_tile_shape_x=x_quant_tile_shape, + quant_tile_shape_w=w_quant_tile_shape, + bias=bias, + out=out.clone() if accumulate else None, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + grad = False + gelu = False + gelu_in = None + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + gelu, + gelu_in, + grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y are not the same tensor + assert y_ref is not y, "y_ref and y should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y = torch.where(y.isnan(), torch.zeros_like(y), y) + + # Check + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + + +def cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, + use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, + expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED", + expected_err_cls=RuntimeError +): + if not fp8_blockwise_gemm_supported(): + pytest.skip("CUDA version does not support blockwise FP8 gemm.") + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + # generate random input and weight + x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0 + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Set quantize_op and quantization parameters + x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) + w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) + x_block_scaling_dim = 1 if is_x_1d_scaled else 2 + w_block_scaling_dim = 1 if is_w_1d_scaled else 2 + x_te_dtype = TE_DType[x_dtype] + w_te_dtype = TE_DType[w_dtype] + x_quantizer = Float8BlockQuantizer( + fp8_dtype=x_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=x_block_scaling_dim, + ) + w_quantizer = Float8BlockQuantizer( + fp8_dtype=w_te_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, + force_pow_2_scales=False, + block_scaling_dim=w_block_scaling_dim, + ) + + # Quantize x and w + qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False) + qx = x_quantizer.update_quantized(x, qx) + qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False) + qw = w_quantizer.update_quantized(w, qw) + + if not use_bias: + bias = None + else: + bias = torch.randn((1, N), dtype=torch.bfloat16, device=device) + + # Allocate cuBLAS workspace + workspace_size = 0 + workspace = torch.empty(0, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + grad = use_grad + gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device) + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + # cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + with pytest.raises(expected_err_cls, match=expected_err_msg): + y = tex.generic_gemm( + qw, + transa, + qx, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_in, + grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (128, 128, 128), + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (16, 64, 128), + (128, 160, 128), + (320, 128, 336), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (320, 256, 336), + (256, 512, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (1024, 4096, 1024), + (512, 128, 512), + (768, 128, 768), + (1024, 128, 1024), + (1536, 128, 1536), + (2048, 128, 2048), + (4096, 128, 4096), + (4096, 512, 3072), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_shape_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_bias( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + rtol = 1e-3 + atol = 0.0 + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_bias=True, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["colxrow", "colxcol", "rowxcol"], +) +def test_cublas_gemm_fp8_blockwise_columnwise( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + is_x_columnwise, + is_w_columnwise, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_split_accumulator_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_bgrad_not_supported_until_tested( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # NOTE: This may work, but until it is tested thoroughly, + # testing that the implementation errors. + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=True, + use_bias=True, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"]) +@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_gelu_not_supported_until_tested( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_bias, + use_grad, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # NOTE: This may work, but until it is tested thoroughly, + # testing that the implementation errors. + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad=use_grad, + use_bias=use_bias, + use_gelu=True, + expected_err_msg=( + "not supported for NVTE_BLOCK_SCALING until further numerical verification" + ), + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_illegal_dtype_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # e5m2 by e5m2 not supported. + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (False, False), + ], + ids=["2Dx2D"], +) +def test_illegal_2D_by_2D_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + # 2D block quantization by 2D block quantization is not supported. + expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported" + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg=expected_err_msg, + ) + + +@pytest.mark.parametrize( + "M, K, N, legalX1d, legalX2d", + [ + # M dim unconstrained when X is 2D. + (255, 128, 256, False, True), + # K must be multiple of 16 + (256, 120, 256, False, False), + # N must be a multiple of 8 + (256, 128, 252, False, False), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (False, True), + (True, True), + ], + ids=["1Dx2D", "2Dx1D", "1Dx1D"], +) +def test_unaligned_shapes( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + legalX1d, + legalX2d, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +) -> None: + legal = legalX1d if is_x_1d_scaled else legalX2d + if not legal: + cublas_gemm_test_constraint_enforced( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + expected_err_msg="dimension requirement", + ) + else: + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + "uniform", # noise type + 1.0, # x_magnitude + 1.0, # w_magnitude + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f19465c44b..a4a0a2c32d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -55,14 +55,23 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { struct GemmParam { void *A; void *B; + // The layout (e.g. TN to call cublas with) cublasOperation_t transA; cublasOperation_t transB; transformer_engine::DType Atype; transformer_engine::DType Btype; void *A_scale_inv; void *B_scale_inv; + // Element stride for A int lda; + // Element stride for B int ldb; + // major and minor number of elements for the + // storage of A, and B of GemmParam + int a_major_dim; + int a_minor_dim; + int b_major_dim; + int b_minor_dim; GemmParam(cublasOperation_t transA, cublasOperation_t transB) : A(nullptr), @@ -74,27 +83,78 @@ struct GemmParam { A_scale_inv(nullptr), B_scale_inv(nullptr), lda(0), - ldb(0) {} + ldb(0), + a_major_dim(0), + a_minor_dim(0), + b_major_dim(0), + b_minor_dim(0) {} }; GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - const int k, const int lda, const int ldb) { + int A0, int A1, int B0, int B1) { using namespace transformer_engine; - // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. - // Must either force them both into a common block scaling mode or loosen this - // restriction. NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret(transA, transB); - ret.lda = lda; - ret.ldb = ldb; + bool transa_bool = transA == CUBLAS_OP_T; + bool transb_bool = transB == CUBLAS_OP_T; + + int arch = cuda::sm_arch(cuda::current_device()); + if (A.scaling_mode == NVTE_BLOCK_SCALING) { + // For this scaling mode, the quantizer stores + // rowwise data and transposes the data for columnwise + // data so the physical layout is always row major + // and the transA and transB values to pass to cublas + // should always be TN. + + ret.a_major_dim = transa_bool ? A0 : A1; + ret.a_minor_dim = transa_bool ? A1 : A0; + ret.b_major_dim = transb_bool ? B1 : B0; + ret.b_minor_dim = transb_bool ? B0 : B1; + + ret.transA = CUBLAS_OP_T; + ret.transB = CUBLAS_OP_N; + ret.lda = ret.a_minor_dim; + ret.ldb = ret.b_minor_dim; + + NVTE_CHECK(ret.a_minor_dim == ret.b_minor_dim, + "Inner dimension must be equal for NVTE_BLOCK_SCALING Gemm."); + + } else { + // In these scaling modes, the physical layout of + // the tensor will always line up with transA and + // transB, which are passed along to cuBLAS. + // NOTE: There is some logic below that may edit this + // decision for A and B depending on dtype and arch. + const int m = transa_bool ? A0 : A1; + const int k = transa_bool ? A1 : A0; + const int n = transb_bool ? B1 : B0; + ret.a_major_dim = A0; + ret.a_minor_dim = A1; + ret.b_major_dim = B0; + ret.b_minor_dim = B1; + + int lda, ldb; + if (transa_bool && !transb_bool) { // TN + lda = k; + ldb = k; + } else if (!transa_bool && !transb_bool) { // NN + lda = m; + ldb = k; + } else if (!transa_bool && transb_bool) { // NT + lda = m; + ldb = n; + } else { // TT + NVTE_ERROR("TT layout not allowed."); + } + ret.lda = lda; + ret.ldb = ldb; + } - // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases - // or need to be treated as `is_tensor_scaling`. if (is_tensor_scaling(A.scaling_mode)) { ret.A = A.data.dptr; ret.A_scale_inv = A.scale_inv.dptr; @@ -103,14 +163,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; if (is_fp8_dtype(ret.Atype)) { - int arch = cuda::sm_arch(cuda::current_device()); if (arch < 100) { // Hopper and Ada - we need to use columnwise_data and change transA NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); ret.A = A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; ret.A_scale_inv = A.columnwise_scale_inv.dptr; - ret.lda = k; + ret.a_major_dim = A1; + ret.a_minor_dim = A0; + ret.lda = A0; } } } @@ -119,29 +180,63 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (transB == CUBLAS_OP_T) { ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; if (is_fp8_dtype(ret.Btype)) { - int arch = cuda::sm_arch(cuda::current_device()); if (arch < 100) { // Hopper and Ada - we need to use columnwise_data and change transA NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); ret.B = B.columnwise_data.dptr; ret.transB = CUBLAS_OP_N; ret.B_scale_inv = B.columnwise_scale_inv.dptr; - ret.ldb = k; + ret.ldb = B0; + ret.b_major_dim = B1; + ret.b_minor_dim = B0; } } } else { ret.Btype = B.data.dtype; } } else { + // MXF8 scaling or NVTE_BLOCK_SCALING // If not tensor scaling (which includes also high precision types), we need to // use the proper version of data - // We leave the transA/B values as is, since Blackwell supports transposes - ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; - ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; - ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; - ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; - ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; - ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + // For MXF8, we leave the transA/B values as is, since Blackwell supports transposes + // but for NVTE_BLOCK_SCALING, we force transA/B to TN since the quantizers + // store data in that manner and the GEMM requires that layout. + if (A.scaling_mode == NVTE_BLOCK_SCALING) { + if (transA == CUBLAS_OP_T) { + NVTE_CHECK(A.has_data(), "Input A is not suitable for rowwise usage!"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); + } + if (transB == CUBLAS_OP_N) { + NVTE_CHECK(B.has_data(), "Input B is not suitable for rowwise usage!"); + } else { + NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); + } + // Requirements from + // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.a_minor_dim % 16) == 0, + "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. + // Smallest supported CType is 2 bytes in this scaling mode. + NVTE_CHECK((ret.a_major_dim % 8) == 0, + "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Observed this requirement only present for B tensor is 1D quantized. + if (B.block_scaling_dim == 1) { + NVTE_CHECK( + (ret.b_major_dim % 8) == 0, + "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } + NVTE_CHECK((ret.lda % 16) == 0, + "A tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + NVTE_CHECK((ret.ldb % 16) == 0, + "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; } return ret; } @@ -153,18 +248,23 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, - int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, - void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, + const Tensor *inputBias, Tensor *outputPreGelu, int A0, int A1, int B0, int B1, + cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int ldd = m; // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { return; } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, A0, A1, B0, B1); + void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -222,10 +322,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, - param.transA == CUBLAS_OP_N ? k : m, param.lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, - param.transB == CUBLAS_OP_N ? n : k, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( + &Adesc, A_type, param.transA == CUBLAS_OP_N ? param.a_major_dim : param.a_minor_dim, + param.transA == CUBLAS_OP_N ? param.a_minor_dim : param.a_major_dim, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( + &Bdesc, B_type, param.transB == CUBLAS_OP_N ? param.b_minor_dim : param.b_major_dim, + param.transB == CUBLAS_OP_N ? param.b_major_dim : param.b_minor_dim, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); @@ -249,12 +352,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized - // GEMM types. - // Scaling factors. #if CUDA_VERSION >= 12080 - cublasLtMatmulMatrixScale_t scaling_mode; + cublasLtMatmulMatrixScale_t scaling_mode_a; + cublasLtMatmulMatrixScale_t scaling_mode_b; #endif if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { void *A_scale_inverse = param.A_scale_inv; @@ -266,8 +367,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); #if CUDA_VERSION >= 12080 - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; - } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -276,7 +378,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. if (cublasLtGetVersion() <= 120803) { @@ -285,6 +388,30 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } +#if CUDA_VERSION >= 12080 + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + int block_scaling_dim_a = inputA->block_scaling_dim; + int block_scaling_dim_b = inputB->block_scaling_dim; + NVTE_CHECK((block_scaling_dim_a == 1 && block_scaling_dim_b == 1) || + (block_scaling_dim_a == 1 && block_scaling_dim_b == 2) || + (block_scaling_dim_a == 2 && block_scaling_dim_b == 1), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got " + + std::to_string(block_scaling_dim_a) + " x " + + std::to_string(block_scaling_dim_b)); + scaling_mode_a = block_scaling_dim_a == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_b = block_scaling_dim_b == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; +#endif #endif } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + @@ -293,9 +420,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if CUDA_VERSION >= 12080 NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); #endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output @@ -305,8 +432,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUDA_VERSION >= 12080 - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + // NOTE: In all current cases where FP8 output is supported, the input is + // scaled identically to the output. + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_D_SCALE_MODE, + &scaling_mode_a, sizeof(scaling_mode_a))); #endif // For FP8 output, cuBLAS requires C_type to match bias_type and // be FP16/BF16 @@ -364,6 +494,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } + if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS || + epilogue == CUBLASLT_EPILOGUE_BGRADB), + "Epilogue (gelu fusion) not supported for NVTE_BLOCK_SCALING until further " + "numerical verification."); + } + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); @@ -411,7 +549,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C @@ -474,27 +611,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons const size_t B0 = inputB->flat_first_dim(); const size_t B1 = inputB->flat_last_dim(); - const int m = transa ? A0 : A1; - const int k = transa ? A1 : A0; - const int n = transb ? B1 : B0; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, A0, A1, B0, B1, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); @@ -525,28 +642,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; - int lda, ldb, ldd; - if (transa && !transb) { // TN - lda = k; - ldb = k; - ldd = m; - } else if (!transa && !transb) { // NN - lda = m; - ldb = k; - ldd = m; - } else if (!transa && transb) { // NT - lda = m; - ldb = n; - ldd = m; - } else { // TT - NVTE_ERROR("TT layout not allowed."); - } - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, inputA->data.shape[0], + inputA->data.shape[1], inputB->data.shape[0], inputB->data.shape[1], (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index dae39d82bf..f6b6ae22c2 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -27,7 +27,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -57,7 +57,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (cudnn_backend) { // TODO: add check for GPU ARCH diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 8519fe1b64..c56f9ef407 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -23,7 +23,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_block_scaling(z->scaling_mode)) { + !is_mxfp_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -47,7 +47,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index cb2121a457..e8e8b06a4c 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -112,6 +112,13 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } + if (!tensor.attr("_quantizer").is_none()) { + // Some calls to makeTransformerEngineTensor pass a NoneQuantizer. + // The quantizer stores settings like block_scaling_dim that are important. + // and are stored indirectly via the quantizer. + auto tensor_meta_quantizer = CreateQuantizer(tensor.attr("_quantizer")); + tensor_meta_quantizer->set_quantization_params(&ret); + } quantizer->set_quantization_params(&ret); return ret; } From 522ffbe14fed6150547e79e73f3edef005d740b9 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 16:50:56 -0700 Subject: [PATCH 05/29] Update NVTE_BLOCK_SCALING for GEMM. Signed-off-by: Keith Wyss --- .../common/gemm/cublaslt_gemm.cu | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index a4a0a2c32d..ff179f3569 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -94,8 +94,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, int A0, int A1, int B0, int B1) { using namespace transformer_engine; - NVTE_CHECK(A.scaling_mode == B.scaling_mode, - "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK(A.scaling_mode == B.scaling_mode || + (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || + (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), + "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret(transA, transB); @@ -104,7 +106,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool transb_bool = transB == CUBLAS_OP_T; int arch = cuda::sm_arch(cuda::current_device()); - if (A.scaling_mode == NVTE_BLOCK_SCALING) { + int a_major_dim; + int b_major_dim; + if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // For this scaling mode, the quantizer stores // rowwise data and transposes the data for columnwise // data so the physical layout is always row major @@ -201,7 +205,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // For MXF8, we leave the transA/B values as is, since Blackwell supports transposes // but for NVTE_BLOCK_SCALING, we force transA/B to TN since the quantizers // store data in that manner and the GEMM requires that layout. - if (A.scaling_mode == NVTE_BLOCK_SCALING) { + if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { if (transA == CUBLAS_OP_T) { NVTE_CHECK(A.has_data(), "Input A is not suitable for rowwise usage!"); } else { @@ -221,7 +225,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK((ret.a_major_dim % 8) == 0, "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Observed this requirement only present for B tensor is 1D quantized. - if (B.block_scaling_dim == 1) { + if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { NVTE_CHECK( (ret.b_major_dim % 8) == 0, "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); @@ -389,8 +393,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, sizeof(dummy_a_vec_stride))); } #if CUDA_VERSION >= 12080 - } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && - (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -399,17 +403,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); - int block_scaling_dim_a = inputA->block_scaling_dim; - int block_scaling_dim_b = inputB->block_scaling_dim; - NVTE_CHECK((block_scaling_dim_a == 1 && block_scaling_dim_b == 1) || - (block_scaling_dim_a == 1 && block_scaling_dim_b == 2) || - (block_scaling_dim_a == 2 && block_scaling_dim_b == 1), - "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got " + - std::to_string(block_scaling_dim_a) + " x " + - std::to_string(block_scaling_dim_b)); - scaling_mode_a = block_scaling_dim_a == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); + scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; - scaling_mode_b = block_scaling_dim_b == 1 ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; #endif #endif From d7e1fce86c4179554e57e15279284d8d119cbb65 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 6 Mar 2025 11:17:27 -0800 Subject: [PATCH 06/29] Gate feature on CUDA 12.9 Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 2 +- transformer_engine/common/gemm/cublaslt_gemm.cu | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index a118c6f81c..3da9e95c17 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -13,7 +13,7 @@ def fp8_blockwise_gemm_supported() -> bool: - return float(torch.version.cuda) >= 12.8 + return float(torch.version.cuda) >= 12.9 def cublas_gemm_fp8_blockwise_case( diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ff179f3569..c51397ac6e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -392,8 +392,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } -#if CUDA_VERSION >= 12080 - } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && +#if CUDA_VERSION >= 12090 +else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); From f212c81dc2cd93da2e363689b90276432eb7d5fa Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 18:18:18 -0700 Subject: [PATCH 07/29] Gemm typo. Signed-off-by: Keith Wyss --- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c51397ac6e..5031db539f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -393,7 +393,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, sizeof(dummy_a_vec_stride))); } #if CUDA_VERSION >= 12090 -else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); From 48b2d57923cb14510e42063c8b566a8c40bd95d4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 10 Mar 2025 17:40:24 -0700 Subject: [PATCH 08/29] Remove unecessary type converter change. Signed-off-by: Keith Wyss --- .../pytorch/csrc/extensions/type_converters.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index e8e8b06a4c..cb2121a457 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -112,13 +112,6 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } - if (!tensor.attr("_quantizer").is_none()) { - // Some calls to makeTransformerEngineTensor pass a NoneQuantizer. - // The quantizer stores settings like block_scaling_dim that are important. - // and are stored indirectly via the quantizer. - auto tensor_meta_quantizer = CreateQuantizer(tensor.attr("_quantizer")); - tensor_meta_quantizer->set_quantization_params(&ret); - } quantizer->set_quantization_params(&ret); return ret; } From 57615893b8f0fe67b2ea5c1f6a8f30160d6699d2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 11 Mar 2025 13:24:05 -0700 Subject: [PATCH 09/29] Reflect epilogue availability and test supported epilogues. Signed-off-by: Keith Wyss --- .../test_float8_blockwise_gemm_exact.py | 130 +++++++++++++++--- .../common/gemm/cublaslt_gemm.cu | 37 ++--- 2 files changed, 134 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 3da9e95c17..c52ced214d 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -34,6 +34,8 @@ def cublas_gemm_fp8_blockwise_case( x_columnwise: bool = False, w_columnwise: bool = False, use_bias: bool = False, + use_gelu: bool = False, + use_grad: bool = False, atol: float = 0.0, rtol: float = 0.0 ): @@ -67,6 +69,7 @@ def cublas_gemm_fp8_blockwise_case( else: out = None + assert not (use_bias and use_grad), "Bias grad not supported by GEMM" # Set quantize_op and quantization parameters x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128) w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128) @@ -142,9 +145,9 @@ def cublas_gemm_fp8_blockwise_case( transa = True if not w_columnwise else False transb = False if not x_columnwise else True out_quantizer = None - grad = False - gelu = False - gelu_in = None + assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM" + aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None + aux_tensor_ref = aux_tensor.clone() if use_gelu else None bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] # cuBLAS GEMM @@ -160,9 +163,9 @@ def cublas_gemm_fp8_blockwise_case( TE_DType[out_dtype], bias, bias_dtype, - gelu, - gelu_in, - grad, + use_gelu, + aux_tensor, + use_grad, workspace, workspace.shape[0], accumulate, @@ -176,8 +179,25 @@ def cublas_gemm_fp8_blockwise_case( y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) y = torch.where(y.isnan(), torch.zeros_like(y), y) - # Check - torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + if use_gelu: + # Check + if use_grad: + # With use_grad, GEMM should use aux tensor to calculate + # gradient + gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None) + # TODO: How do we decide whether this is acceptably close? + # Could also try to put the activation inside the reference + # before the output cast to see different tolerances. + torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2) + else: + # aux tensor is pre-gelu aux output. Verify against y_ref. + torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol) + act = torch.nn.GELU() + gelu_ref = act(y_ref) + # gelu_ref = tex.gelu(y_ref, None) + torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol) + else: + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) def cublas_gemm_test_constraint_enforced( @@ -509,6 +529,84 @@ def test_cublas_gemm_fp8_blockwise_columnwise( ) +@pytest.mark.parametrize( + "M, K, N", + [ + # k = 128 + (256, 128, 256), + # non 128x128 divisible input shapes + (16, 128, 128), + (320, 64, 336), + # k > 128 + (256, 256, 256), + (4096, 128, 4096), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +@pytest.mark.parametrize( + "use_grad", + [ + True, + ], + ids=["grad"], +) +def test_cublas_gemm_fp8_gelu( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_grad, +): + # NOTE: cuBLAS doesn't complain with not use_grad, but the tests don't succeed + # so the epilogue is disabled on the transformer engine side. + if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled): + pytest.skip( + "CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)." + ) + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + use_gelu=True, + use_grad=use_grad, + ) + + @pytest.mark.parametrize( "M, K, N", [ @@ -577,7 +675,7 @@ def test_split_accumulator_enforced( ], ids=["1Dx2D", "1Dx1D", "2Dx1D"], ) -def test_bgrad_not_supported_until_tested( +def test_bgrad_not_supported( x_dtype, w_dtype, out_dtype, @@ -589,8 +687,7 @@ def test_bgrad_not_supported_until_tested( is_x_1d_scaled, is_w_1d_scaled, ) -> None: - # NOTE: This may work, but until it is tested thoroughly, - # testing that the implementation errors. + # NOTE: BGRAD epilogue is not supported for fp8. cublas_gemm_test_constraint_enforced( x_dtype, w_dtype, @@ -604,6 +701,7 @@ def test_bgrad_not_supported_until_tested( is_w_1d_scaled, use_grad=True, use_bias=True, + expected_err_msg="Epilogue requested outside of the available", ) @@ -630,7 +728,7 @@ def test_bgrad_not_supported_until_tested( ], ids=["1Dx2D", "1Dx1D", "2Dx1D"], ) -def test_gelu_not_supported_until_tested( +def test_gelu_unsupported_cases_error( x_dtype, w_dtype, out_dtype, @@ -644,8 +742,8 @@ def test_gelu_not_supported_until_tested( is_x_1d_scaled, is_w_1d_scaled, ) -> None: - # NOTE: This may work, but until it is tested thoroughly, - # testing that the implementation errors. + if use_grad and not use_bias: + pytest.skip("DGELU epilogue is supported.") cublas_gemm_test_constraint_enforced( x_dtype, w_dtype, @@ -660,9 +758,7 @@ def test_gelu_not_supported_until_tested( use_grad=use_grad, use_bias=use_bias, use_gelu=True, - expected_err_msg=( - "not supported for NVTE_BLOCK_SCALING until further numerical verification" - ), + expected_err_msg="Epilogue requested outside of the available", ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 5031db539f..419230ecdc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -94,10 +94,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, int A0, int A1, int B0, int B1) { using namespace transformer_engine; - NVTE_CHECK(A.scaling_mode == B.scaling_mode || - (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || - (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), - "Inputs A and B to GEMM need to have compatible scaling modes!"); + NVTE_CHECK( + A.scaling_mode == B.scaling_mode || + (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) || + (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D), + "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); GemmParam ret(transA, transB); @@ -393,8 +394,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, sizeof(dummy_a_vec_stride))); } #if CUDA_VERSION >= 12090 - } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && - (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { + } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && + (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -404,12 +407,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse))); NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && - inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), + inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)), "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported got 2D by 2D"); - scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F - : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; - scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F - : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D + ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D + ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F + : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; #endif #endif } else { @@ -493,12 +498,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type))); } - if ((inputA->scaling_mode == NVTE_BLOCK_SCALING) && - (inputB->scaling_mode == NVTE_BLOCK_SCALING)) { + if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) || + (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) { NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS || - epilogue == CUBLASLT_EPILOGUE_BGRADB), - "Epilogue (gelu fusion) not supported for NVTE_BLOCK_SCALING until further " - "numerical verification."); + epilogue == CUBLASLT_EPILOGUE_DGELU), + "Epilogue requested outside of the available and tested cuBLAS functionality for " + "float8 block scaled GEMM"); } NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, From 07b19b7bfbdd72bf9036cdaeb41532b7faf91eca Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 11 Mar 2025 17:47:46 -0700 Subject: [PATCH 10/29] GEMM simplifications from recipe branch. Signed-off-by: Keith Wyss --- .../test_float8_blockwise_gemm_exact.py | 5 -- .../common/gemm/cublaslt_gemm.cu | 80 ++++++------------- 2 files changed, 23 insertions(+), 62 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index c52ced214d..2b2911c32f 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -15,7 +15,6 @@ def fp8_blockwise_gemm_supported() -> bool: return float(torch.version.cuda) >= 12.9 - def cublas_gemm_fp8_blockwise_case( x_dtype, w_dtype, @@ -432,8 +431,6 @@ def test_cublas_gemm_fp8_blockwise_bias( is_x_1d_scaled, is_w_1d_scaled, ): - rtol = 1e-3 - atol = 0.0 cublas_gemm_fp8_blockwise_case( x_dtype, w_dtype, @@ -449,8 +446,6 @@ def test_cublas_gemm_fp8_blockwise_bias( is_x_1d_scaled, is_w_1d_scaled, use_bias=True, - atol=atol, - rtol=rtol, ) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 419230ecdc..ed13fccaef 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -62,16 +62,10 @@ struct GemmParam { transformer_engine::DType Btype; void *A_scale_inv; void *B_scale_inv; - // Element stride for A + // ld are leading dimensions or minor dimensions + // in storage int lda; - // Element stride for B int ldb; - // major and minor number of elements for the - // storage of A, and B of GemmParam - int a_major_dim; - int a_minor_dim; - int b_major_dim; - int b_minor_dim; GemmParam(cublasOperation_t transA, cublasOperation_t transB) : A(nullptr), @@ -83,11 +77,7 @@ struct GemmParam { A_scale_inv(nullptr), B_scale_inv(nullptr), lda(0), - ldb(0), - a_major_dim(0), - a_minor_dim(0), - b_major_dim(0), - b_minor_dim(0) {} + ldb(0) {} }; GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, @@ -116,18 +106,15 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // and the transA and transB values to pass to cublas // should always be TN. - ret.a_major_dim = transa_bool ? A0 : A1; - ret.a_minor_dim = transa_bool ? A1 : A0; - ret.b_major_dim = transb_bool ? B1 : B0; - ret.b_minor_dim = transb_bool ? B0 : B1; + a_major_dim = transa_bool ? A0 : A1; + b_major_dim = transb_bool ? B1 : B0; + ret.lda = transa_bool ? A1 : A0; + ret.ldb = transb_bool ? B0 : B1; ret.transA = CUBLAS_OP_T; ret.transB = CUBLAS_OP_N; - ret.lda = ret.a_minor_dim; - ret.ldb = ret.b_minor_dim; - NVTE_CHECK(ret.a_minor_dim == ret.b_minor_dim, - "Inner dimension must be equal for NVTE_BLOCK_SCALING Gemm."); + NVTE_CHECK(ret.lda == ret.ldb, "Minor dimension must be equal for NVTE_BLOCK_SCALING Gemm."); } else { // In these scaling modes, the physical layout of @@ -135,29 +122,14 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // transB, which are passed along to cuBLAS. // NOTE: There is some logic below that may edit this // decision for A and B depending on dtype and arch. - const int m = transa_bool ? A0 : A1; - const int k = transa_bool ? A1 : A0; - const int n = transb_bool ? B1 : B0; - ret.a_major_dim = A0; - ret.a_minor_dim = A1; - ret.b_major_dim = B0; - ret.b_minor_dim = B1; - - int lda, ldb; - if (transa_bool && !transb_bool) { // TN - lda = k; - ldb = k; - } else if (!transa_bool && !transb_bool) { // NN - lda = m; - ldb = k; - } else if (!transa_bool && transb_bool) { // NT - lda = m; - ldb = n; - } else { // TT + a_major_dim = A0; + b_major_dim = B0; + ret.lda = A1; + ret.ldb = B1; + + if (transa_bool && transb_bool) { // TT NVTE_ERROR("TT layout not allowed."); } - ret.lda = lda; - ret.ldb = ldb; } if (is_tensor_scaling(A.scaling_mode)) { @@ -174,8 +146,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A = A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; ret.A_scale_inv = A.columnwise_scale_inv.dptr; - ret.a_major_dim = A1; - ret.a_minor_dim = A0; + a_major_dim = A1; ret.lda = A0; } } @@ -191,9 +162,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B = B.columnwise_data.dptr; ret.transB = CUBLAS_OP_N; ret.B_scale_inv = B.columnwise_scale_inv.dptr; + b_major_dim = B1; ret.ldb = B0; - ret.b_major_dim = B1; - ret.b_minor_dim = B0; } } } else { @@ -219,20 +189,18 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } // Requirements from // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK((ret.a_minor_dim % 16) == 0, + NVTE_CHECK((ret.lda % 16) == 0, "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. - NVTE_CHECK((ret.a_major_dim % 8) == 0, + NVTE_CHECK((a_major_dim % 8) == 0, "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Observed this requirement only present for B tensor is 1D quantized. if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { NVTE_CHECK( - (ret.b_major_dim % 8) == 0, + (b_major_dim % 8) == 0, "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); } - NVTE_CHECK((ret.lda % 16) == 0, - "A tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); NVTE_CHECK((ret.ldb % 16) == 0, "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); } @@ -327,12 +295,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( - &Adesc, A_type, param.transA == CUBLAS_OP_N ? param.a_major_dim : param.a_minor_dim, - param.transA == CUBLAS_OP_N ? param.a_minor_dim : param.a_major_dim, param.lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate( - &Bdesc, B_type, param.transB == CUBLAS_OP_N ? param.b_minor_dim : param.b_major_dim, - param.transB == CUBLAS_OP_N ? param.b_major_dim : param.b_minor_dim, param.ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, + param.transA == CUBLAS_OP_N ? k : m, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, + param.transB == CUBLAS_OP_N ? n : k, param.ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); From c4a41b88cb0e628098b6e821bdd4eb91134ae923 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 14 Mar 2025 17:24:32 -0700 Subject: [PATCH 11/29] Format py code. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 2b2911c32f..022f754444 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -15,6 +15,7 @@ def fp8_blockwise_gemm_supported() -> bool: return float(torch.version.cuda) >= 12.9 + def cublas_gemm_fp8_blockwise_case( x_dtype, w_dtype, From 51ed2fb604289900f1b0677477f92ae2b5156c42 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Tue, 1 Apr 2025 10:17:29 -0700 Subject: [PATCH 12/29] Update GEMM DGelu tests to match support depending on output dtype. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 022f754444..eedd7056c9 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -540,7 +540,7 @@ def test_cublas_gemm_fp8_blockwise_columnwise( ) @pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("noise_type", ["normal"], ids=str) @pytest.mark.parametrize("x_magnitude", [1], ids=str) @pytest.mark.parametrize("w_magnitude", [1], ids=str) @@ -738,8 +738,12 @@ def test_gelu_unsupported_cases_error( is_x_1d_scaled, is_w_1d_scaled, ) -> None: - if use_grad and not use_bias: - pytest.skip("DGELU epilogue is supported.") + if use_grad and not use_bias and out_dtype == torch.bfloat16: + pytest.skip("DGELU epilogue is supported for bfloat16.") + elif use_grad and not use_bias: + expected_err = "an unsupported value or parameter was passed" + else: + expected_err = "Epilogue requested outside of the available" cublas_gemm_test_constraint_enforced( x_dtype, w_dtype, @@ -754,7 +758,7 @@ def test_gelu_unsupported_cases_error( use_grad=use_grad, use_bias=use_bias, use_gelu=True, - expected_err_msg="Epilogue requested outside of the available", + expected_err_msg=expected_err, ) From e7af1404abd43de0b9bcea30378dff5c0f78e212 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 11:29:28 -0700 Subject: [PATCH 13/29] Force pow2Scales in GEMM Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index eedd7056c9..728f84fb2f 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -82,7 +82,7 @@ def cublas_gemm_fp8_blockwise_case( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=x_block_scaling_dim, ) w_quantizer = Float8BlockQuantizer( @@ -90,7 +90,7 @@ def cublas_gemm_fp8_blockwise_case( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=w_block_scaling_dim, ) @@ -252,7 +252,7 @@ def cublas_gemm_test_constraint_enforced( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=x_block_scaling_dim, ) w_quantizer = Float8BlockQuantizer( @@ -260,7 +260,7 @@ def cublas_gemm_test_constraint_enforced( rowwise=True, columnwise=True, amax_epsilon=0.0, - force_pow_2_scales=False, + force_pow_2_scales=True, block_scaling_dim=w_block_scaling_dim, ) From 596a00912553e457909254f5b4217cba27ad0423 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 11:42:26 -0700 Subject: [PATCH 14/29] Add GEMM test to pytorch test suite. Signed-off-by: Keith Wyss --- qa/L0_pytorch_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eaededc4..1206012195 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail " python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" From 4aa6067ef3707e2f023d7bbc7fe42f165b57812d Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 2 Apr 2025 12:17:59 -0700 Subject: [PATCH 15/29] Add copyright to GEMM test. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 728f84fb2f..61cdef742c 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + import pytest import torch import transformer_engine as te From 758dc4a2cc1c4c3476635f87cd16a0b7687643e2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Thu, 3 Apr 2025 18:05:13 -0700 Subject: [PATCH 16/29] Update import for GEMM test. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 61cdef742c..94014d36b5 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -12,8 +12,8 @@ Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from tests.pytorch.references.blockwise_quantizer_reference import CuBLASScaleMunger -from tests.pytorch.references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm +from references.blockwise_quantizer_reference import CuBLASScaleMunger +from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: From 7d5b5d99865501923613f11ba37da30141818a30 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 10:25:20 -0700 Subject: [PATCH 17/29] Add license. Signed-off-by: Keith Wyss --- tests/pytorch/references/blockwise_fp8_gemm_reference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pytorch/references/blockwise_fp8_gemm_reference.py b/tests/pytorch/references/blockwise_fp8_gemm_reference.py index 3487dfb810..5aef986e37 100644 --- a/tests/pytorch/references/blockwise_fp8_gemm_reference.py +++ b/tests/pytorch/references/blockwise_fp8_gemm_reference.py @@ -1,3 +1,7 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + from typing import Tuple import torch From ff884e20fcd7b15130120577279a3f37d644acb5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 4 Apr 2025 14:47:16 -0400 Subject: [PATCH 18/29] [JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644) * rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout * add fatten_axis option * added gated act to test encoder * sharding constraint fixes * fix padding when flattening first dim needs to be padded * update test sizes so that padding is tested * rm output sharding as it can be done in the flax module * sharding scale_inv for mxfp8 --------- Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 15 +- tests/jax/test_custom_call_compute.py | 238 ++++++----- tests/jax/test_distributed_layernorm_mlp.py | 50 +-- transformer_engine/jax/activation.py | 1 - .../jax/cpp_extensions/activation.py | 390 +++++++++--------- transformer_engine/jax/cpp_extensions/gemm.py | 23 +- transformer_engine/jax/cpp_extensions/misc.py | 42 +- .../jax/cpp_extensions/normalization.py | 30 +- .../jax/cpp_extensions/quantization.py | 298 +++++++------ .../jax/csrc/extensions/activation.cpp | 169 ++++---- .../jax/csrc/extensions/gemm.cpp | 4 +- transformer_engine/jax/csrc/extensions/misc.h | 2 +- .../jax/csrc/extensions/pybind.cpp | 10 +- .../jax/csrc/extensions/quantization.cpp | 92 +++-- transformer_engine/jax/dense.py | 59 ++- transformer_engine/jax/flax/module.py | 152 ++++--- transformer_engine/jax/layernorm_dense.py | 28 +- transformer_engine/jax/layernorm_mlp.py | 75 +++- .../jax/quantize/dequantizer.py | 19 +- transformer_engine/jax/quantize/quantizer.py | 155 ++++--- .../jax/quantize/scaling_modes.py | 114 +++-- transformer_engine/jax/quantize/tensor.py | 178 ++++++-- transformer_engine/jax/sharding.py | 30 +- 23 files changed, 1288 insertions(+), 886 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 977c3c2912..7e6605c9fe 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -57,13 +57,14 @@ def __call__(self, x, mask, disable_dropout=False): self_attn_mask_type="padding", enable_relative_embedding=False, enable_sequence_parallel=self.enable_seq_paral, + mlp_activations=("gelu", "linear"), ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) if self.enable_seq_paral: - # Trigger all-gather to collect a complete tensor alone seqence on each device. + # Trigger all-gather to collect a complete tensor alone sequence on each device. x = jax.lax.with_sharding_constraint( x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) ) @@ -459,7 +460,7 @@ def setUpClass(cls): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -467,7 +468,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -475,14 +476,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -491,7 +492,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -500,7 +501,7 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.50 and actual[1] > 0.76 + assert actual[0] < 0.455 and actual[1] > 0.785 if __name__ == "__main__": diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1efc7e1f3c..4dc07a2eea 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -29,7 +29,7 @@ ScaledTensor, ScalingMode, QuantizerFactory, - QuantizeAxis, + QuantizeLayout, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): if isinstance(a, ScaledTensor1x): - if a.layout == "T": - b_transpose = jnp.transpose(b, (-1, *range(b.ndim - 1))) + if a.data_layout == "T": + flatten_axis = a.data.ndim - a.flatten_axis + b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis))) assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype) else: assert_allclose(a.dequantize(), b, dtype=a.data.dtype) @@ -141,7 +142,8 @@ def primitive_func(self, inputs, activation_type, quantizer): def test_act_grad(self, shape, activation_type): key = jax.random.PRNGKey(0) x = jax.random.uniform(key, shape, jnp.float32) - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) @@ -159,7 +161,8 @@ def test_act_grad(self, shape, activation_type): @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type value_n_grad_primitive_func = jit( @@ -169,7 +172,7 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=output_type, - q_axis=QuantizeAxis.ROWWISE, + q_layout=QuantizeLayout.ROWWISE, ) prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) @@ -182,19 +185,22 @@ def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_act_forward_with_delayed_scaling_fp8( - self, random_inputs, activation_type, output_type, q_axis + self, random_inputs, activation_type, output_type, q_layout ): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=output_type, - q_axis=q_axis, + q_layout=q_layout, ) te_output = tex.act_lu(x, activation_type, te_quantizer) @@ -203,19 +209,21 @@ def test_act_forward_with_delayed_scaling_fp8( assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) - @pytest_parametrize_wrapper("shape", [(128, 128)]) + @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)]) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_act_forward_with_block_scaling_fp8( - self, random_inputs, activation_type, output_type, q_axis + self, random_inputs, activation_type, output_type, q_layout ): x = random_inputs - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.repeat(x, len(activation_type), axis=-2) self.activation_type = activation_type quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_axis=q_axis + scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) output = tex.act_lu(x, activation_type, quantizer) @@ -324,9 +332,11 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp @pytest.mark.skipif(not is_fp8_supported, reason=reason) # No Norm FWD E5M2 in TE backend @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_norm_grad_with_delayed_scaling_fp8( - self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis + self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout ): """ Test transformer_engine.jax.layernorm.layernorm @@ -335,7 +345,9 @@ def test_norm_grad_with_delayed_scaling_fp8( pytest.skip("RMSNorm and zero_centered_gamma is not supported!") quantizer = QuantizerFactory.create( - scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_axis=q_axis + scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, + q_dtype=out_dtype, + q_layout=q_layout, ) self._test_norm_grad( n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer @@ -351,7 +363,7 @@ def _test_norm_forward( inp_dtype, out_dtype, scaling_mode, - q_axis, + q_layout, ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 3) @@ -363,7 +375,7 @@ def _test_norm_forward( gamma = jnp.asarray(gamma, inp_dtype) quantizer, ref_quantizer = QuantizerFactory.create( - n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_axis=q_axis + n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout ) if norm_type == "layernorm": beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1) @@ -391,9 +403,11 @@ def _test_norm_forward( @pytest.mark.skipif(not is_fp8_supported, reason=reason) # No Norm FWD E5M2 in TE backend @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_norm_forward_with_delayed_scaling_fp8( - self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis + self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout ): if norm_type == "rmsnorm" and zero_centered_gamma is True: pytest.skip("RMSNorm and zero_centered_gamma is not supported!") @@ -407,7 +421,7 @@ def test_norm_forward_with_delayed_scaling_fp8( inp_dtype=inp_dtype, out_dtype=out_dtype, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, - q_axis=q_axis, + q_layout=q_layout, ) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @@ -424,7 +438,7 @@ def test_norm_forward_with_block_scaling_fp8( inp_dtype=inp_dtype, out_dtype=out_dtype, scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, - q_axis=QuantizeAxis.ROWWISE_COLWISE, + q_layout=QuantizeLayout.ROWWISE_COLWISE, ) @@ -434,14 +448,14 @@ def test_norm_forward_with_block_scaling_fp8( } ALL_QUANTIZE_TEST_SHAPES = [ - (128, 128), - (4, 256, 512), + (32, 64), + (2, 64, 32), ] QUANTIZE_TEST_SHAPES = { "L0": [ - (256, 128), - (64, 16, 2, 256), + (32, 256, 128), + (64, 32, 32, 256), ], "L2": ALL_QUANTIZE_TEST_SHAPES, } @@ -457,48 +471,52 @@ def test_norm_forward_with_block_scaling_fp8( @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) +@pytest_parametrize_wrapper("flatten_axis", [-1, -2]) @pytest_parametrize_wrapper( - "q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE] + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] ) class TestQuantize: """ Purely quantization related tests that will always test on a wider set of types and shapes """ - def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis): + def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): key = jax.random.PRNGKey(0) # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling) quantizer = QuantizerFactory.create( scaling_mode=scaling_mode, q_dtype=q_dtype, - q_axis=q_axis, + q_layout=q_layout, ) + # Adding dimension to test if padding is done correctly when flatten 3D to 2D + if flatten_axis == -2: + input_shape = input_shape[:-1] + (2,) + input_shape[-1:] n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) - scaled_tensor = quantizer.quantize(x) + scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis) assert_dequantized_scaled_tensor(scaled_tensor, x) - def test_quantize_bitwise(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis): - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( - input_shape - ): - pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") + def test_quantize_bitwise( + self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis + ): key = jax.random.PRNGKey(0) + if flatten_axis == -2: + input_shape = input_shape[:-1] + (2,) + input_shape[-1:] input = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( - n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis + n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - jax_output = _jax_quantize(input, quantizer=jax_quantizer) + jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(input, quantizer=te_quantizer) - assert_bitwise_scaled_tensors(jax_output, te_output) + te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + assert_bitwise_scaled_tensors(te_output, jax_output) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @@ -508,9 +526,13 @@ class TestFusedQuantize: @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) - def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis): - transpose_axis = -1 + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] + ) + @pytest_parametrize_wrapper("flatten_axis", [-1, -2]) + def test_quantize_dbias( + self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis + ): if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( input_shape ): @@ -520,35 +542,37 @@ def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_ input = jax.random.uniform(key, input_shape, in_dtype) jax_quantizer, te_quantizer = QuantizerFactory.create( - n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis + n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - te_output, te_dbias = jit(lambda input: tex.quantize_dbias(input, quantizer=te_quantizer))( - input - ) + te_output, te_dbias = jit( + lambda input: tex.quantize_dbias( + input, quantizer=te_quantizer, flatten_axis=flatten_axis + ) + )(input) jax_output, jax_dbias = jit( lambda input: _jax_quantize_dbias( - input, - quantizer=jax_quantizer, + input, quantizer=jax_quantizer, flatten_axis=flatten_axis ) )(input) - assert_bitwise_scaled_tensors(jax_output, te_output) + assert_bitwise_scaled_tensors(te_output, jax_output) - assert_allclose(jax_dbias, te_dbias) + assert_allclose(te_dbias, jax_dbias) def _test_quantize_dact_dbias( - self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_axis + self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) - x = jnp.repeat(x, len(activation_type), axis=-1) + x = jnp.expand_dims(x, axis=-2) + x = jnp.repeat(x, len(activation_type), axis=-2) dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1) jax_quantizer, te_quantizer = QuantizerFactory.create( - n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis + n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) is_casted_output = te_quantizer is not None @@ -573,12 +597,12 @@ def _test_quantize_dact_dbias( )(dz, x) if is_casted_output: - assert_bitwise_scaled_tensors(jax_output, te_output) + assert_bitwise_scaled_tensors(te_output, jax_output) else: - assert_allclose(jax_output, te_output) + assert_allclose(te_output, jax_output) if is_dbias: - assert_allclose(jax_dbias, te_dbias) + assert_allclose(te_dbias, jax_dbias) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @@ -597,7 +621,7 @@ def test_quantize_dact_dbias_no_quantization( scaling_mode=ScalingMode.NVTE_NO_SCALING, activation_type=activation_type, is_dbias=is_dbias, - q_axis=QuantizeAxis.ROWWISE, + q_layout=QuantizeLayout.ROWWISE, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -605,9 +629,11 @@ def test_quantize_dact_dbias_no_quantization( @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_quantize_dact_dbias_delayed_scaling( - self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis + self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout ): self._test_quantize_dact_dbias( in_dtype=in_dtype, @@ -616,7 +642,7 @@ def test_quantize_dact_dbias_delayed_scaling( scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, activation_type=activation_type, is_dbias=is_dbias, - q_axis=q_axis, + q_layout=q_layout, ) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @@ -626,9 +652,11 @@ def test_quantize_dact_dbias_delayed_scaling( ) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) - @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + ) def test_quantize_dact_dbias_mxfp8_scaling( - self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis + self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout ): if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0: # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes. @@ -645,75 +673,75 @@ def test_quantize_dact_dbias_mxfp8_scaling( scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, activation_type=activation_type, is_dbias=is_dbias, - q_axis=q_axis, + q_layout=q_layout, ) class TestDense: - def _ref_gemm_with_jnp_dot(self, a, b, layout): - if layout[0] == "T": + def _ref_gemm_with_jnp_dot(self, a, b, data_layout): + if data_layout[0] == "T": a = jnp.swapaxes(a, -1, -2) - if layout[1] == "T": + if data_layout[1] == "T": b = jnp.swapaxes(b, -1, -2) return jnp.dot(a, b) - def _generate_gemm_input(self, m, n, k, layout): + def _generate_gemm_input(self, m, n, k, data_layout): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform( subkeys[0], - (m if layout[0] == "N" else k, k if layout[0] == "N" else m), + (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), dtype=jnp.bfloat16, ) / jnp.sqrt(k) w = jax.random.uniform( subkeys[1], - (k if layout[1] == "N" else n, n if layout[1] == "N" else k), + (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), dtype=jnp.bfloat16, ) / jnp.sqrt(n) - lhs_contracting_dim = (1,) if layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if layout[1] == "N" else (1,) + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) return (x, w, contracting_dims) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) - @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_bf16(self, m, n, k, layout): - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) + def test_gemm_bf16(self, m, n, k, data_layout): + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) primitive_out = tex.gemm(x, w, contracting_dims) - ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout): - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) + def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False ) primitive_out = tex.gemm( x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set ) - ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=q_dtype) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def test_dense_grad_bf16(self, m, n, k): - layout = "NN" - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + data_layout = "NN" + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) def primitive_func(x, w, contracting_dims): primitive_out = dense(x, w, contracting_dims=contracting_dims) return jnp.mean(primitive_out) - def ref_func(x, w, layout): - return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout)) + def ref_func(x, w, data_layout): + return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout)) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1)) @@ -722,19 +750,19 @@ def ref_func(x, w, layout): primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func( x, w, contracting_dims ) - ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout) + ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): - layout = "NN" - x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) + data_layout = "NN" + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) key = jax.random.PRNGKey(1) bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) @@ -745,9 +773,9 @@ def primitive_func(x, w, bias, contracting_dims, quantizer_set): ) return jnp.mean(primitive_out) - def ref_func(x, w, bias, layout): + def ref_func(x, w, bias, data_layout): return jnp.mean( - self._ref_gemm_with_jnp_dot(x, w, layout) + jnp.expand_dims(bias, axis=0) + self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0) ) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) @@ -763,7 +791,9 @@ def ref_func(x, w, bias, layout): value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) ) - ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(x, w, bias, layout) + ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func( + x, w, bias, data_layout + ) assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) @@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", [(512, 128, 128)]) + @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @@ -873,7 +903,7 @@ def ref_func(x, w, gamma, beta): assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", [(512, 128, 256)]) + @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @@ -898,13 +928,13 @@ def test_layernorm_mlp_grad( x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) kernel_1 = jax.random.normal( - subkeys[1], (k, len(activation_type) * n), jnp.bfloat16 + subkeys[1], (k, len(activation_type), n), jnp.bfloat16 ) / jnp.sqrt(k) kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n) gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16) beta = None # was tested in TestNorm if use_bias: - bias_1 = jax.random.normal(subkeys[3], (len(activation_type) * n), jnp.bfloat16) + bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16) bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) else: bias_1 = None @@ -1039,19 +1069,19 @@ def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list): subkeys = jax.random.split(key, len(shape_list) * 2) lhs_list, rhs_list, contracting_dims_list = [], [], [] - for i, ((m, n, k), layout) in enumerate(zip(shape_list, layout_list)): + for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)): lhs = jax.random.uniform( subkeys[2 * i], - (m if layout[0] == "N" else k, k if layout[0] == "N" else m), + (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), dtype=dtype, ) rhs = jax.random.uniform( subkeys[2 * i + 1], - (k if layout[1] == "N" else n, n if layout[1] == "N" else k), + (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), dtype=dtype, ) - lhs_contracting_dim = (1,) if layout[0] == "N" else (0,) - rhs_contracting_dim = (0,) if layout[1] == "N" else (1,) + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) lhs_list.append(lhs) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index efc24fe6ea..4350d5e8f3 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -45,11 +45,17 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) DTYPES = [jnp.bfloat16, jnp.float16] -INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in] +INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) +KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) +KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES) +LN_SCALE_AXES = (W_NO_SHARD_AXES,) +LN_BIAS_AXES = (W_NO_SHARD_AXES,) +BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) +BIAS_2_AXES = (W_NO_SHARD_AXES,) INTERMEDIATE = 64 @@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs(): configs.append( [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] ) - if is_devices_enough(4): configs.append( [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] @@ -80,13 +85,13 @@ def generate_inputs(self, input_shape, activation_type, use_bias, dtype): x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) k1 = jax.random.normal( - subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype + subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype ) / jnp.sqrt(hidden_in) k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt( INTERMEDIATE ) if use_bias: - b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype) + b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) else: b1 = None @@ -111,10 +116,12 @@ def layernorm_fp8_mlp_prim_func( layernorm_input_axes = LAYERNORM_INPUT_AXES dot_1_input_axes = DOT_1_INPUT_AXES dot_2_input_axes = DOT_2_INPUT_AXES + kernel_1_axes = KERNEL_1_AXES + kernel_2_axes = KERNEL_2_AXES else: layernorm_input_axes = None - dot_1_input_axes = None - dot_2_input_axes = None + dot_1_input_axes = dot_2_input_axes = None + kernel_1_axes = kernel_2_axes = None quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) @@ -130,6 +137,8 @@ def layernorm_fp8_mlp_prim_func( norm_input_axes=layernorm_input_axes, dot_1_input_axes=dot_1_input_axes, dot_2_input_axes=dot_2_input_axes, + kernel_1_axes=kernel_1_axes, + kernel_2_axes=kernel_2_axes, activation_type=activation_type, quantizer_sets=quantizer_sets, ) @@ -142,7 +151,7 @@ def layernorm_fp8_mlp_prim_func( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - def test_layernorm_fp8_mlp_primitive( + def test_layernorm_mlp_grad( self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe ): device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config @@ -168,12 +177,12 @@ def test_layernorm_fp8_mlp_primitive( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): - k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp")) + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) k1_ = jax.device_put(k1, k1_sharding) k2_ = jax.device_put(k2, k2_sharding) if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec("tp")) + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) b1_ = jax.device_put(b1, b1_sharding) else: b1_sharding = b1_ = None @@ -267,16 +276,7 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, - scale_axes=(W_NO_SHARD_AXES,), - ln_bias_axes=(W_NO_SHARD_AXES,), - kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), - kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), use_bias=use_bias, - bias_axes_1=(W_JOINED_AXES, W_TP_AXES), - bias_axes_2=(W_NO_SHARD_AXES,), - layernorm_input_axes=LAYERNORM_INPUT_AXES, - dot_1_input_axes=DOT_1_INPUT_AXES, - dot_2_input_axes=DOT_2_INPUT_AXES, ) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) mlp_out_single, ln_out_single = ln_mlp_single.apply( @@ -295,13 +295,13 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, intermediate_dim=INTERMEDIATE, activations=activation_type, - scale_axes=(W_NO_SHARD_AXES,), - ln_bias_axes=(W_NO_SHARD_AXES,), - kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), - kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), + scale_axes=LN_SCALE_AXES, + ln_bias_axes=LN_BIAS_AXES, + kernel_axes_1=KERNEL_1_AXES, + kernel_axes_2=KERNEL_2_AXES, use_bias=use_bias, - bias_axes_1=(W_JOINED_AXES, W_TP_AXES), - bias_axes_2=(W_NO_SHARD_AXES,), + bias_axes_1=BIAS_1_AXES, + bias_axes_2=BIAS_2_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES, @@ -334,7 +334,7 @@ def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - def test_layernorm_fp8_mlp_layer( + def test_layernorm_mlp_layer_fp8( self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe ): self._test_layernorm_mlp( diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index a2d0a6f4d9..ef6def2d03 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g): (x, _) = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type) - dx = jnp.reshape(dx, x.shape) return (dx, None) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 70227e1620..d7676781c3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -26,12 +26,12 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias +from .quantization import _jax_dbias, _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ( Quantizer, - QuantizeAxis, + QuantizeLayout, DelayedScaleQuantizer, ScalingMode, ) @@ -110,41 +110,31 @@ def abstract( """ te_act_lu_p abstract """ - del act_enum, act_len, scale_shapes + del act_enum, scale_shapes dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - - out_shape = ( - *x_aval.shape[:-2], - 1, - x_aval.shape[-1], + assert x_aval.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x_aval.shape} and act_len {act_len}" ) + + out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer) - - if len(rowwise_scale_inv_shape) > 1: - rowwise_scale_inv_shape = ( - rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:] - ) - if len(colwise_scale_inv_shape) > 1: - colwise_scale_inv_shape = ( - colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:] - ) - + ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1) + if not is_2x: + out_shape = (1,) + colwise_scale_inv_shape = (1,) + colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - - colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) - if is_2x: - colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval @@ -211,15 +201,8 @@ def impl( ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - rowwise_scale_inv_shape = ( - rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:] - ) - if is_2x: - colwise_scale_inv_shape = ( - colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:] - ) + ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1) + # Slice out padding for MXFP8, noop for DelayedScaling scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) @@ -227,6 +210,7 @@ def impl( colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) + return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax @staticmethod @@ -292,11 +276,14 @@ def infer_sharding_from_operands( is_outer, ) # Unused. x_spec = get_padded_spec(arg_infos[0]) - out_spec = (*x_spec[:-2], None, x_spec[-2]) + scale_spec = get_padded_spec(arg_infos[1]) + + out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") + if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(out_spec) + colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec else: @@ -304,18 +291,24 @@ def infer_sharding_from_operands( colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv" ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") - - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "ActLuPrimitive.colwise_scale_inv" + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax") + colwise_scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv" ) + return ( out_sharding, colwise_out_sharding, @@ -340,14 +333,14 @@ def partition( ): del result_infos, is_outer # Unused. x_spec = get_padded_spec(arg_infos[0]) - out_spec = (*x_spec[:-1], x_spec[-1]) - if act_len == 2 and x_spec[-1] is None: - # Ensure last axis is partitioned and not the gating axis - out_spec = (*x_spec[:-2], None, x_spec[-2]) + scale_spec = get_padded_spec(arg_infos[1]) + + out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") + if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(out_spec) + colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: colwise_out_spec = out_spec else: @@ -355,21 +348,25 @@ def partition( colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = out_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv" ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") - - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "ActLuPrimitive.colwise_scale_inv" + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax") + colwise_scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv" ) - arg_shardings = list(arg_i.sharding for arg_i in arg_infos) - arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec)) - arg_shardings = tuple(arg_shardings) + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, colwise_out_sharding, @@ -413,6 +410,7 @@ def sharded_impl(x, scale): register_primitive(ActLuPrimitive) +# TODO(Jeremy): replace is_2x with q_layout class DActLuDBiasQuantizePrimitive(BasePrimitive): """ DActLu DBias Cast Transpose Primitive @@ -445,42 +443,41 @@ def abstract( te_dact_dbias_quantize_p abstract """ del act_enum, scale_shapes - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype + dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dz_dtype + assert x_aval.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x_aval.shape} and act_len {act_len}" + ) assert scale_aval.dtype == jnp.float32 ir_hidden_size = dz_aval.shape[-1] - gi_hidden_size = x_aval.shape[-1] + gi_hidden_size = act_len * x_aval.shape[-1] assert act_len * ir_hidden_size == gi_hidden_size out_shape = x_aval.shape out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) + updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) - - scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - - colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) - - dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) if is_2x: - # Don't transpose output for MXFP8 - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - t_shape = out_shape + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: - t_shape = multidim_transpose(out_shape) - colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) + colwise_out_shape = out_shape + else: + colwise_out_shape = (1,) + colwise_scale_inv_shape = (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) + scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) if is_dbias: - dbias_shape = gi_hidden_size - dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype) + dbias_shape = (act_len, ir_hidden_size) (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes( x_aval.size // gi_hidden_size, gi_hidden_size, @@ -489,9 +486,14 @@ def abstract( scaling_mode, is_2x, ) - wkspace_aval = x_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) + wkspace_shape = wkspace_info[0] + wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) + else: + dbias_shape = (1,) + wkspace_shape = (1,) + wkspace_dtype = jnp.float32 + dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype) + wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype) return ( out_aval, @@ -587,23 +589,16 @@ def impl( ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x.shape, is_padded=False) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv = jax.lax.slice( - scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape + ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2) + # Slice out padding for MXFP8, noop for DelayedScaling + scale_inv = jax.lax.slice( + scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape + ) + if is_2x: + colwise_scale_inv = jax.lax.slice( + colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) - if is_2x: - colwise_scale_inv = jax.lax.slice( - colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape - ) - return ( - out, - colwise_out, - scale_inv, - colwise_scale_inv, - updated_amax, - dbias, - ) # Exclude wkspace + return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias @staticmethod def batcher( @@ -670,15 +665,16 @@ def infer_sharding_from_operands( result_infos, ): del out_dtype, result_infos, act_enum - del scale_dtype, scale_shapes, is_dbias, act_len, is_outer + del scale_dtype, scale_shapes, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) + scale_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" ) if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_x_spec = multidim_transpose(x_spec) + colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec else: @@ -687,23 +683,32 @@ def infer_sharding_from_operands( mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" ) - dbias_shaprding = NamedSharding( + dbias_spec = x_spec[-2:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), + PartitionSpec(*dbias_spec), desc="DActLuDBiasQuantizePrimitive.dbias", ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" ) amax_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax" + mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DActLuDBiasQuantizePrimitive.colwise_scale_inv" + colwise_scale_inv_sharding = NamedSharding( + mesh, + PartitionSpec(*colwise_scale_inv_spec), + desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv", ) return ( out_sharding, @@ -711,7 +716,7 @@ def infer_sharding_from_operands( scale_inv_sharding, colwise_scale_inv_sharding, amax_sharding, - dbias_shaprding, + dbias_sharding, ) @staticmethod @@ -731,10 +736,15 @@ def partition( ): del result_infos, is_outer x_spec = get_padded_spec(arg_infos[1]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out") + scale_spec = get_padded_spec(arg_infos[2]) + + out_sharding = NamedSharding( + mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" + ) + if is_2x: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_x_spec = multidim_transpose(x_spec) + colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: colwise_x_spec = x_spec else: @@ -743,38 +753,39 @@ def partition( mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" ) - dbias_shaprding = NamedSharding( + dbias_spec = x_spec[-2:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), + PartitionSpec(*dbias_spec), desc="DActLuDBiasQuantizePrimitive.dbias", ) + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if is_2x: + colwise_scale_inv_spec = scale_inv_spec + scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" + mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv" ) - amax_sharding = NamedSharding( - mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax" - ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DActLuDBiasQuantizePrimitive.colwise_scale_inv" + amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax") + colwise_scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv" ) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - arg_shardings = ( - arg_shardings[1], - arg_shardings[1], - *arg_shardings[2:], - ) # dz and x are the same + out_shardings = ( out_sharding, colwise_out_sharding, scale_inv_sharding, colwise_scale_inv_sharding, amax_sharding, - dbias_shaprding, + dbias_sharding, ) def sharded_impl(dz, x, scale): @@ -816,14 +827,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S """ JAX native activation implementation """ - x = jnp.split(inputs, len(activation_type), axis=-1) + act_len = len(activation_type) + assert inputs.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {inputs.shape} and act_len {act_len}" + ) + + x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): x_i = _convert_to_activation_function(act_fn)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) + x = jnp.squeeze(x, axis=-2) if quantizer: - return quantizer.quantize(x) + return quantizer.quantize(x, flatten_axis=-1) return x @@ -837,6 +855,12 @@ def _jax_quantize_dact_dbias( """ JAX implementation of dact_lu and dbias with optional quantization """ + act_len = len(activation_type) + assert x.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x.shape} and act_len {act_len}" + ) + _, vjp_func = jax.vjp( partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) ) @@ -844,10 +868,10 @@ def _jax_quantize_dact_dbias( dbias = None if is_dbias: - dbias = _jax_dbias(dx).astype(x.dtype) + dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2) if quantizer is not None: - dx = quantizer.quantize(dx, dq_dtype=x.dtype) + dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2) else: dx = dx.astype(x.dtype) @@ -863,6 +887,7 @@ def act_lu( Args: x: Input tensor to be processed. + Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. @@ -873,12 +898,17 @@ def act_lu( A ScaledTensor containing the quantized activated input. """ act_type_id = ActivationEnum[activation_type].value + act_len = len(activation_type) + assert x.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x.shape} and act_len {act_len}" + ) if not ActLuPrimitive.enabled(): return _jax_act_lu(x, activation_type, quantizer) # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_act_lu(x, activation_type, quantizer) # TE/common does not support 2x quantization for DelayedScaling yet @@ -889,16 +919,15 @@ def act_lu( return war_output scale = jnp.empty((1,), jnp.float32) - output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type)) + output_shape = (*x.shape[:-2], x.shape[-1]) if quantizer is None: - x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type))) out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, scale, out_dtype=x.dtype, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, is_2x=False, scale_dtype=jnp.float32, @@ -911,7 +940,6 @@ def act_lu( if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale - x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type))) ( rowwise_casted_output, colwise_casted_output, @@ -923,25 +951,15 @@ def act_lu( scale, out_dtype=quantizer.q_dtype, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), - scale_shapes=quantizer.get_scale_shapes(output_shape), + # output does not have act axis + scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1), is_outer=True, ) - rowwise_casted_output = rowwise_casted_output.reshape(output_shape) - if len(rowwise_scale_inv.shape) > 1: - rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis - if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE): - colwise_output_shape = output_shape - if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - colwise_output_shape = multidim_transpose(output_shape) - colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape) - if len(colwise_scale_inv.shape) > 1: - colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis - quantizer.update(updated_amax) return ScaledTensorFactory.create( @@ -951,8 +969,8 @@ def act_lu( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), ) @@ -968,7 +986,7 @@ def quantize_dact_dbias( Args: dz: Gradient of the output with respect to the activation output. x: Input tensor that was processed by the forward pass. - Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations + Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. @@ -979,21 +997,25 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the bias. """ + act_len = len(activation_type) + assert x.shape[-2] == act_len, ( + "activation input should be replicated by act_len in the -2 axis, got input shape" + f" {x.shape} and act_len {act_len}" + ) + if not DActLuDBiasQuantizePrimitive.enabled(): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): - out, _ = quantize_dact_dbias( - dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None - ) - return quantize_dbias(out, is_dbias=True, quantizer=quantizer) + out = dact_lu(dz, x, activation_type, quantizer=None) + return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2) - is_gated = len(activation_type) == 2 + is_gated = act_len == 2 # TE/common does not support DelayedScaling2x for gated-act yet if is_gated: war_output = try_apply_delayed_scaling_2x_war( @@ -1003,6 +1025,7 @@ def quantize_dact_dbias( activation_type=activation_type, is_dbias=is_dbias, quantizer=quantizer, + flatten_axis=-2, ) if war_output is not None: return war_output @@ -1025,12 +1048,12 @@ def quantize_dact_dbias( scale_shapes=((), ()), # unused is_dbias=False, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, is_outer=True, ) dbias = None if is_dbias: - dbias = _jax_dbias(output).astype(x.dtype) + dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) return output.astype(x.dtype), dbias if isinstance(quantizer, DelayedScaleQuantizer): @@ -1041,16 +1064,9 @@ def quantize_dact_dbias( dgated = dact_lu( dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type ) - # TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests - if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype) - else: - out, dbias = quantize_dbias( - dgated, - quantizer=quantizer, - is_dbias=True, - dq_dtype=x.dtype, - ) + out, dbias = _quantize_dbias_impl( + dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + ) return out, dbias out_shape = x.shape @@ -1070,10 +1086,11 @@ def quantize_dact_dbias( scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), - scale_shapes=quantizer.get_scale_shapes(out_shape), + # output has act axis + scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2), is_dbias=is_dbias, act_enum=act_type_id, - act_len=len(activation_type), + act_len=act_len, is_outer=True, ) @@ -1090,8 +1107,9 @@ def quantize_dact_dbias( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), + flatten_axis=-2, # as output has act axis ) return out, dbias diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0fad75817f..736105dd75 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -6,9 +6,9 @@ from typing import Tuple, Sequence, Union, Dict, List from functools import partial, reduce import operator -from transformer_engine_jax import get_device_compute_capability import jax import jax.numpy as jnp +from transformer_engine_jax import get_device_compute_capability from .base import BasePrimitive, register_primitive @@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): # Reshape + Transpose # [..., M, K] -> [B, M, K] # [..., K, M] -> [B, M, K] - lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N") - rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T") + lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N") + rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T") - # _shape_normalization ensures contracting_dims=2 and batch_dims=0 dim_nums = (((2,), (2,)), ((0,), (0,))) out_3d = jax.lax.dot_general( lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype @@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8( ), "rhs does not have delayed tensor scaling mode" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums - if lhs.layout == "T": + if lhs.data_layout == "T": lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract) - if rhs.layout == "T": + if rhs.data_layout == "T": rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract) lhs_dn = (lhs_contract, lhs_batch) @@ -403,19 +402,19 @@ def grouped_gemm( lhs_shape = lhs.data.shape rhs_shape = rhs.data.shape out_dtype = lhs.dq_dtype - # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout + # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: assert not ( lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 ), "FP8 GEMM does not support E5M2 * E5M2" ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims - if lhs.layout == "T": + if lhs.data_layout == "T": lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim - if rhs.layout == "T": + if rhs.data_layout == "T": rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) else: - # For jnp.ndarray, only consider contracting_dims, layout is always NN + # For jnp.ndarray, only consider contracting_dims, data_layout is always NN scaling_mode = ScalingMode.NVTE_NO_SCALING lhs_shape = lhs.shape rhs_shape = rhs.shape @@ -432,8 +431,8 @@ def grouped_gemm( lhs_3d = _shape_normalization(lhs, lhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn) elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: - lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N") - rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T") + lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N") + rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T") elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: lhs_3d = _shape_normalization(lhs.data, lhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 980ea556bb..c79eda5568 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -19,7 +19,7 @@ import transformer_engine_jax from ..sharding import get_padded_spec as te_get_padded_spec -from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis +from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout TEDType = transformer_engine_jax.DType @@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim): return axis if axis >= 0 else ndim + axis -def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1): +def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1): """ te_cast_transpose_p multi-dims transpose static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be involved into transpose, -1 means all axes involve into transpose. - transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for - transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary + transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for + transpose. Note, transpose_axis should be greater than static_axis_boundary examples: X in shape (dim0, dim1, dim2, dim3, dim4) - static_axis_boundary == -1, transpose_axis_boundary == 2 + static_axis_boundary == -1, transpose_axis == 2 Xt = (dim2, dim3, dim4, dim0, dim1) - static_axis_boundary == 0, transpose_axis_boundary == 2 + static_axis_boundary == 0, transpose_axis == 2 Xt = (dim0, dim2, dim3, dim4, dim1) - static_axis_boundary == 0, transpose_axis_boundary == 3 + static_axis_boundary == 0, transpose_axis == 3 Xt = (dim0, dim3, dim4, dim1. dim2) """ if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. transpose_start_idx = static_axis_boundary + 1 - transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape)) - assert transpose_start_idx < transpose_axis_boundary + transpose_axis = normalize_axis_boundary(transpose_axis, len(shape)) + assert transpose_start_idx < transpose_axis return ( *shape[:transpose_start_idx], - *shape[transpose_axis_boundary:], - *shape[transpose_start_idx:transpose_axis_boundary], + *shape[transpose_axis:], + *shape[transpose_start_idx:transpose_axis], ) @@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant break return ( quantizer is not None - and quantizer.q_axis == QuantizeAxis.ROWWISE + and quantizer.q_layout == QuantizeLayout.ROWWISE and arch_l_100 and is_dbias ) -def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): +def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs): """ Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling. It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result. @@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): # 2x is not supported by TE kernels for delayed scaling # so revert to 1x and transpose in JAX - quantizer.q_axis = QuantizeAxis.ROWWISE + quantizer.q_layout = QuantizeLayout.ROWWISE rowwise = f(*args, **kwargs, quantizer=quantizer) other_outputs = None if isinstance(rowwise, tuple): other_outputs = rowwise[1:] rowwise = rowwise[0] - quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE - colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1))) + quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE + if flatten_axis < 0: + flatten_axis += rowwise.data.ndim + assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds" + colwise_data = jnp.transpose( + rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis)) + ) output_2x = ScaledTensorFactory.create( data=rowwise.data, scale_inv=rowwise.scale_inv, @@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): colwise_scale_inv=rowwise.scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=rowwise.dq_dtype, - q_axis=QuantizeAxis.ROWWISE_COLWISE, - layout=quantizer.get_layout(), + q_layout=QuantizeLayout.ROWWISE_COLWISE, + data_layout=quantizer.get_data_layout(), + flatten_axis=flatten_axis, ) if other_outputs is not None: return (output_2x,) + other_outputs diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 4a342dd4e0..74882c92db 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -30,7 +30,7 @@ from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ( Quantizer, - QuantizeAxis, + QuantizeLayout, DelayedScaleQuantizer, ScalingMode, ) @@ -277,14 +277,14 @@ def impl( rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( x.shape, is_padded=False ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: - scale_inv = scale_inv.flatten()[ - : reduce(operator.mul, rowwise_scale_inv_shape) - ].reshape(rowwise_scale_inv_shape) - if is_2x: - colwise_scale_inv = colwise_scale_inv.flatten()[ - : reduce(operator.mul, colwise_scale_inv_shape) - ].reshape(colwise_scale_inv_shape) + # slice out padding for mxfp8, noop for DelayedScaling + scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( + rowwise_scale_inv_shape + ) + if is_2x: + colwise_scale_inv = colwise_scale_inv.flatten()[ + : reduce(operator.mul, colwise_scale_inv_shape, 1) + ].reshape(colwise_scale_inv_shape) return ( out, colwise_out, @@ -816,7 +816,7 @@ def layernorm_fwd( return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -900,8 +900,8 @@ def layernorm_fwd( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), ) return scaled_tensor, mu, rsigma @@ -997,7 +997,7 @@ def rmsnorm_fwd( return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -1082,8 +1082,8 @@ def rmsnorm_fwd( colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, dq_dtype=x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), ) return scaled_tensor, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 551b4b4bdb..034e149c50 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -2,6 +2,8 @@ # # See LICENSE for license information. """JAX/TE custom ops for quantization""" +import operator +from functools import reduce from typing import Tuple, Optional from packaging import version @@ -24,7 +26,7 @@ ) from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory -from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode +from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax import ffi # pylint: disable=ungrouped-imports @@ -50,7 +52,8 @@ class DBiasQuantizePrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer + 9, + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer inner_primitive = None outer_primitive = None @@ -61,7 +64,8 @@ def abstract( *, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -73,49 +77,52 @@ def abstract( del scale_shapes dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + out_shape = x_aval.shape assert scale_aval is None or scale_aval.dtype == jnp.float32 - rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) - - if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + rowwise_out_shape = out_shape + else: + rowwise_out_shape = (1,) + rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) + ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) + else: + colwise_out_shape = out_shape + else: + colwise_out_shape = (1,) + colwise_scale_inv_shape = (1,) + colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - - colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) - - dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - t_shape = multidim_transpose(x_aval.shape) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - # Don't transpose output for MXFP8 - t_shape = x_aval.shape - colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype) - colwise_scale_inv_aval = jax.core.ShapedArray( - shape=colwise_scale_inv_shape, dtype=scale_dtype - ) + colwise_scale_inv_aval = jax.core.ShapedArray( + shape=colwise_scale_inv_shape, dtype=scale_dtype + ) if is_dbias: - gi_hidden_size = x_aval.shape[-1] - dbias_shape = (gi_hidden_size,) - dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype) + dbias_shape = x_aval.shape[flatten_axis:] + gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1) (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes( x_aval.size // gi_hidden_size, gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), ) - wkspace_aval = x_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) + wkspace_shape = wkspace_info[0] + wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) + else: + dbias_shape = (1,) + wkspace_shape = (1,) + wkspace_dtype = jnp.float32 + dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype) + wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype) return ( rowwise_out_aval, @@ -151,7 +158,8 @@ def lowering( *, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -169,7 +177,8 @@ def lowering( x, scale, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, is_dbias=is_dbias, ) @@ -179,7 +188,8 @@ def impl( scale, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -203,7 +213,8 @@ def impl( scale, out_dtype=out_dtype, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, scale_dtype=scale_dtype, scale_shapes=scale_shapes, is_dbias=is_dbias, @@ -211,16 +222,14 @@ def impl( ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x.shape, is_padded=False) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - scale_inv = jax.lax.slice( - scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape - ) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): - colwise_scale_inv = jax.lax.slice( - colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape - ) + ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis) + scale_inv = jax.lax.slice( + scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape + ) + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + colwise_scale_inv = jax.lax.slice( + colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape + ) return ( out, colwise_out, @@ -237,7 +246,8 @@ def batcher( *, out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -260,7 +270,8 @@ def batcher( scale, out_dtype=out_dtype, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, scale_dtype=scale_dtype, scale_shapes=scale_shapes, is_dbias=is_dbias, @@ -272,7 +283,8 @@ def batcher( def infer_sharding_from_operands( out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -281,16 +293,17 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer) # Unused. + del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused. x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding( mesh, - PartitionSpec(*x_spec[:-1], x_spec[-1]), + PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.out_sharding", ) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(x_spec) + colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec else: @@ -300,26 +313,35 @@ def infer_sharding_from_operands( PartitionSpec(*colwise_out_spec), desc="DBiasQuantizePrimitive.colwise_out_sharding", ) - scale_inv_sharding = NamedSharding( + + dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(*get_padded_spec(arg_infos[1])), - desc="DBiasQuantizePrimitive.scale_inv", + PartitionSpec(*dbias_spec), + desc="DBiasQuantizePrimitive.dbias_sharding", ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description( - desc="DBiasQuantizePrimitive.amax_sharding" + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + colwise_scale_inv_spec = scale_inv_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DBiasQuantizePrimitive.colwise_scale_inv" + amax_sharding = NamedSharding( + mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" ) - dbias_sharding = NamedSharding( + colwise_scale_inv_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), - desc="DBiasQuantizePrimitive.dbias_sharding", + PartitionSpec(*colwise_scale_inv_spec), + desc="DBiasQuantizePrimitive.colwise_scale_inv", ) + return ( out_sharding, colwise_out_sharding, @@ -333,7 +355,8 @@ def infer_sharding_from_operands( def partition( out_dtype, scaling_mode, - q_axis, + q_layout, + flatten_axis, scale_dtype, scale_shapes, is_dbias, @@ -344,14 +367,15 @@ def partition( ): del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) + scale_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding( mesh, - PartitionSpec(*x_spec[:-1], x_spec[-1]), + PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.out_sharding", ) - if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: - colwise_out_spec = multidim_transpose(x_spec) + colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec else: @@ -361,26 +385,35 @@ def partition( PartitionSpec(*colwise_out_spec), desc="DBiasQuantizePrimitive.colwise_out_sharding", ) - scale_inv_sharding = NamedSharding( + + dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) + dbias_sharding = NamedSharding( mesh, - PartitionSpec(*get_padded_spec(arg_infos[1])), - desc="DBiasQuantizePrimitive.scale_inv", + PartitionSpec(*dbias_spec), + desc="DBiasQuantizePrimitive.dbias_sharding", ) - amax_sharding = scale_inv_sharding.duplicate_with_new_description( - desc="DBiasQuantizePrimitive.amax_sharding" + + scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: + scale_inv_spec = amax_spec = scale_spec + elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: + scale_inv_spec = x_spec + + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + colwise_scale_inv_spec = scale_inv_spec + + scale_inv_sharding = NamedSharding( + mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" ) - if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: - scale_inv_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv" - ) - colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description( - "DBiasQuantizePrimitive.colwise_scale_inv" + amax_sharding = NamedSharding( + mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" ) - dbias_sharding = NamedSharding( + colwise_scale_inv_sharding = NamedSharding( mesh, - PartitionSpec(x_spec[-1]), - desc="DBiasQuantizePrimitive.dbias_sharding", + PartitionSpec(*colwise_scale_inv_spec), + desc="DBiasQuantizePrimitive.colwise_scale_inv", ) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, @@ -404,7 +437,8 @@ def sharded_impl(x, scale): scale, out_dtype=out_dtype, scaling_mode=scaling_mode, - q_axis=q_axis, + q_layout=q_layout, + flatten_axis=flatten_axis, scale_dtype=scale_dtype, scale_shapes=scale_shapes, is_dbias=is_dbias, @@ -436,49 +470,45 @@ def sharded_impl(x, scale): register_primitive(DBiasQuantizePrimitive) -def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None): +def _jax_quantize( + x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 +): if quantizer is None: return x - return quantizer.quantize(x, dq_dtype=dq_dtype) + return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) -def _jax_dbias(dx: jnp.ndarray): +def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): + assert flatten_axis < 0 + dtype = dtype or dx.dtype dbias = jnp.sum( - dx, - axis=tuple(range(dx.ndim - 1)), + dx.astype(jnp.float32), + axis=tuple(range(dx.ndim + flatten_axis)), keepdims=False, ) - dbias = dbias.ravel() # C++ function returns an 1D array for dbias - return dbias + return dbias.astype(dtype) def _jax_quantize_dbias( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ): if quantizer is None: return x, None - return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x) - - -def _jax_dbias( - dx: jnp.ndarray, -): - dbias = jnp.sum( - dx.astype(jnp.float32), - axis=tuple(range(dx.ndim - 1)), - keepdims=False, + return ( + quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis), + _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis), ) - dbias = dbias.ravel() # C++ function returns an 1D array for dbias - return dbias.astype(dx.dtype) -def _quantize_impl( +def _quantize_dbias_impl( x: jnp.ndarray, quantizer: Quantizer, is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -488,40 +518,51 @@ def _quantize_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + dq_dtype = dq_dtype or x.dtype + if not DBiasQuantizePrimitive.enabled(): if is_dbias: return _jax_quantize_dbias( x, quantizer=quantizer, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None + return ( + _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), + None, + ) # TE/common doesn't support colwise only quantization yet - if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: + if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if is_dbias: return _jax_quantize_dbias( x, quantizer=quantizer, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None + return ( + _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), + None, + ) scale = jnp.empty((), jnp.float32) # TE/common dbias_quantize does not support 1x on arch < 100 if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): - out, _ = _quantize_impl( + out, _ = _quantize_dbias_impl( x=x, is_dbias=False, quantizer=quantizer, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - dbias = _jax_dbias(x) + dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias if quantizer is None: if is_dbias: - return x, _jax_dbias(x) + return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return x, None if isinstance(quantizer, DelayedScaleQuantizer): @@ -539,9 +580,10 @@ def _quantize_impl( scale, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_axis=quantizer.q_axis.value, + q_layout=quantizer.q_layout.value, + flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), - scale_shapes=quantizer.get_scale_shapes(x.shape), + scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis), is_dbias=is_dbias, is_outer=True, ) @@ -557,18 +599,18 @@ def _quantize_impl( colwise_data=colwise_casted_output, colwise_scale_inv=colwise_scale_inv, scaling_mode=quantizer.scaling_mode, - dq_dtype=dq_dtype if dq_dtype is not None else x.dtype, - q_axis=quantizer.q_axis, - layout=quantizer.get_layout(), + dq_dtype=dq_dtype, + q_layout=quantizer.q_layout, + data_layout=quantizer.get_data_layout(), + flatten_axis=flatten_axis, ) - return out, dbias + return out, dbias.astype(dq_dtype) -# TODO(Phuong): do not expose dq_dtype to users def quantize( x: jnp.ndarray, quantizer: Quantizer, - dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -576,26 +618,25 @@ def quantize( x: Input tensor to be quantized. Shape: (..., K) where K is the hidden size. quantizer: Quantizer for FP8 quantization of the output. - dq_dtype: Optional dtype for dequantization. - If None, uses the same dtype as the input tensor. + flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. + Defaults to -1. Returns: A ScaledTensor containing the quantized input tensor. """ - out, _ = _quantize_impl( + out, _ = _quantize_dbias_impl( x, quantizer=quantizer, - dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) return out -# TODO(Phuong): do not expose dq_dtype to users def quantize_dbias( dz: jnp.ndarray, quantizer: Quantizer, is_dbias: bool = True, - dq_dtype: Optional[jnp.dtype] = None, + flatten_axis: int = -1, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -604,8 +645,8 @@ def quantize_dbias( Shape: (..., K) where K is the hidden size. quantizer: Quantizer for FP8 quantization of the output. is_dbias: If True, compute bias gradient. Defaults to True. - dq_dtype: Optional dtype for dequantization. - If None, uses the same dtype as the input tensor. + flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. + Defaults to -1. Returns: A tuple containing: @@ -614,9 +655,6 @@ def quantize_dbias( - The bias gradient tensor. Shape: (K,) or empty if is_dbias is False. """ - return _quantize_impl( - dz, - quantizer=quantizer, - is_dbias=is_dbias, - dq_dtype=dq_dtype, + return _quantize_dbias_impl( + dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis ) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 861db97a26..e71597e4b3 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -11,14 +11,6 @@ #include "transformer_engine/cast.h" #include "xla/ffi/api/c_api.h" -namespace { -bool is_gated(NVTE_Activation_Type act_type) { - return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU || - act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU || - act_type == NVTE_Activation_Type::SREGLU; -} -} // namespace - namespace transformer_engine { namespace jax { @@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto act_len = input_dims[input_dims.size() - 2]; auto scaling_mode = static_cast(scaling_mode_enum); auto is_2x = static_cast(is_2x_int); + auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto input_shape = std::vector{m, act_len * n}; auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); auto output_tensor = TensorWrapper(scaling_mode); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); if (is_fp8_dtype(out_dtype)) { - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{ - product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); - } - - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { - NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - cudaMemsetAsync(amax, 0, sizeof(float), stream); - output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + cudaMemsetAsync(amax, 0, sizeof(float), stream); + output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); + } } if (is_2x) { - output_tensor.set_columnwise_data(colwise_output, static_cast(out_dtype), output_shape); - output_tensor.set_columnwise_scale_inv( - colwise_scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), - std::vector{product(colwise_scale_inv_buf->dimensions(), 0, - colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + auto &tmp_shape = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); + + if (is_fp8_dtype(out_dtype)) { + // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling + auto &tmp_buf = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{ + product(tmp_buf->dimensions(), 0, flatten_axis), + product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + } + } } switch (act_type) { @@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid } if (is_2x) { - output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, - output_trans_shape); + auto &tmp_shape = scaling_mode == static_cast(NVTE_DELAYED_TENSOR_SCALING) + ? output_trans_shape + : output_shape; + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (is_fp8_dtype(out_dtype)) { @@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type output_trans_buf, - Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, - Result_Type amax_out_buf, Result_Type dbias_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, bool is_dbias, int64_t act_enum) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); @@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto *input = input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data(); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *amax = reinterpret_cast(amax_buf->untyped_data()); auto scaling_mode = static_cast(scaling_mode_enum); + auto act_type = static_cast(act_enum); + auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis auto *output = output_buf->untyped_data(); - auto *output_trans = output_trans_buf->untyped_data(); + auto *colwise_output = colwise_output_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data(); void *workspace = workspace_buf->untyped_data(); @@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto act_input_dims = act_input_buf.dimensions(); auto workspace_dims = workspace_buf->dimensions(); // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims - // n = ir_dz_shape[-1], ir_dz_shape == input_dims - auto input_ranks = input_dims.size(); - auto act_input_ranks = act_input_dims.size(); - auto m = product(act_input_dims, 0, act_input_dims.size() - 1); - // 'n' will be 2x the size of input_dims.back() if the dactivation is dgated - auto n = act_input_dims.back(); - auto input_shape = std::vector{m, input_dims.back()}; - auto act_input_shape = std::vector{m, n}; - auto output_shape = std::vector{m, n}; - auto output_trans_shape = std::vector{m, n}; - auto dbias_shape = std::vector{n}; + // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims + auto act_len = act_input_dims[act_input_dims.size() - 2]; + NVTE_CHECK(act_input_dims.back() == input_dims.back(), + "Shape mismatch between activation input and gradient input"); + auto m = product(act_input_dims, 0, act_input_dims.size() - 2); + auto n = input_dims.back(); + + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n * act_len}; + auto output_shape = std::vector{m, n * act_len}; + auto output_trans_shape = std::vector{n * act_len, m}; + auto dbias_shape = std::vector{n * act_len}; std::vector workspace_shape(workspace_dims.begin(), workspace_dims.end()); auto input_tensor = TensorWrapper(input, input_shape, in_dtype); @@ -231,50 +248,56 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto output_tensor = TensorWrapper(scaling_mode); output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_fp8_dtype(out_dtype)) { - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{ - product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); - cudaMemsetAsync(amax_out, 0, sizeof(float), stream); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + cudaMemsetAsync(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax_out, DType::kFloat32, std::vector{1}); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); } } if (is_2x) { - output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + auto &tmp_shape = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape); if (is_fp8_dtype(out_dtype)) { // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &colwise_scale_inv_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; - output_tensor.set_columnwise_scale_inv( - colwise_scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), - std::vector{product(colwise_scale_inv_buf->dimensions(), 0, - colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + auto &tmp_buf = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{ + product(tmp_buf->dimensions(), 0, flatten_axis), + product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + } } } auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); - auto act_type = static_cast(act_enum); - // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead - NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && - is_gated(act_type)), - "TE/common does not support delayed scaling for 2x with gated activations."); + NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); + NVTE_CHECK( + !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { switch (act_type) { diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 74909319cc..c1e008a5bc 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh cudaStreamSynchronize(stream); // Notes on matrix layouts and transpose: - // Jax uses row-major layout, on entering this function, each input matrix pair: + // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major with size [m, k], // B: row-major with size [n, k], needs transpose, // on exiting this function, JAX expect: // C: row-major with size [m, n]. - // cuBLAS uses column-major layout, in this view, each input matrix pair: + // cuBLAS uses column-major data_layout, in this view, each input matrix pair: // A: column-major with size [k, m], needs transpose, // B: column-major with size [k, n]. // If we call cuBLAS GEMM for A * B, the output will be: diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 09ccf6be86..c8526e20c0 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -34,7 +34,7 @@ inline size_t product(const std::vector &shape) { return ret; } -enum class QuantizeAxis { +enum class QuantizeLayout { ROWWISE, COLWISE, ROWWISE_COLWISE, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index c777a02c99..ebdfe461c7 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .export_values(); - pybind11::enum_(m, "QuantizeAxis", - pybind11::module_local()) - .value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE) - .value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE) - .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE) + pybind11::enum_(m, "QuantizeLayout", + pybind11::module_local()) + .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) + .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) + .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) .export_values(); } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index c8f98dd43f..b48ee8a9b9 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type output_trans_buf, - Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, - Result_Type amax_out_buf, Result_Type dbias_buf, - Result_Type workspace_buf, int64_t scaling_mode_enum, - int64_t quantize_axis_enum, bool is_dbias) { + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, + int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias, + int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); auto scaling_mode = static_cast(scaling_mode_enum); - auto const quantize_axis = static_cast(quantize_axis_enum); + auto const quantize_layout = static_cast(quantize_layout_enum); auto *output = output_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data(); @@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T void *workspace = workspace_buf->untyped_data(); auto input_dims = input_buf.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + auto workspace_dims = workspace_buf->dimensions(); - auto m = product(input_dims, 0, input_dims.size() - 1); - auto n = input_dims.back(); + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); auto input_shape = std::vector{m, n}; auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; @@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(scaling_mode); - if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { + if (quantize_layout == QuantizeLayout::ROWWISE || + quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{ - product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); - } - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); - NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); - output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - cudaMemsetAsync(amax_out, 0, sizeof(float), stream); - output_tensor.set_amax(amax_out, DType::kFloat32, std::vector{1}); + if (is_fp8_dtype(out_dtype)) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *amax = reinterpret_cast(amax_buf->untyped_data()); + NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); + cudaMemsetAsync(amax, 0, sizeof(float), stream); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); + } + } } - if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { - output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + if (quantize_layout == QuantizeLayout::COLWISE || + quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + auto &tmp_shape = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; + output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling - auto &colwise_scale_inv_buf = - (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; - output_tensor.set_columnwise_scale_inv( - colwise_scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), - std::vector{product(colwise_scale_inv_buf->dimensions(), 0, - colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + auto &tmp_buf = + (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf; + + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{1}); + } else { + output_tensor.set_columnwise_scale_inv( + tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), + std::vector{ + product(tmp_buf->dimensions(), 0, flatten_axis), + product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + } } auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); @@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // dbias .Ret() // wkspace .Attr("scaling_mode") - .Attr("q_axis") - .Attr("is_dbias"), + .Attr("q_layout") + .Attr("is_dbias") + .Attr("flatten_axis"), FFI_CudaGraph_Traits); Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 43336768cb..2ef8b91c86 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -15,7 +15,11 @@ import jax.numpy as jnp from . import cpp_extensions as tex -from .quantize import QuantizerSet, noop_quantizer_set +from .quantize import ( + QuantizerSet, + noop_quantizer_set, + with_sharding_constraint_by_logical_axes, +) def dense( @@ -23,6 +27,8 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + input_axes: Tuple[str, ...] = None, + kernel_axes: Tuple[str, ...] = None, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -48,12 +54,12 @@ def dense( bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) else: - output = _dense(x, kernel, bias, contracting_dims, quantizer_set) + output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) return output -@partial(jax.custom_vjp, nondiff_argnums=(3,)) -def _dense(x, kernel, bias, contracting_dims, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) +def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set): kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification + input_axes: Logical axes for sharding the activation input + kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ - output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set) + output, _ = _dense_fwd_rule( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set + ) return output -def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): +def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): """Forward pass rule for dense layer transformation. - Args: - x: Input tensor - kernel: Weight matrix - bias: Optional bias tensor - contracting_dims: Contracting dimensions specification - quantizer_set: QuantizerSet which contains quantizers for different tensor types - Returns: Tuple of (output, context) for backward pass """ x_contracting_dims, k_contracting_dims = contracting_dims - casted_x = tex.quantize(x, quantizer_set.x) - casted_kernel = tex.quantize(kernel, quantizer_set.kernel) + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + + casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) + casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) + + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel + ) + casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # GEMM NN output = tex.gemm( @@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): casted_kernel.get_colwise_tensor(), (x_contracting_dims, k_contracting_dims), ) + use_bias = bias is not None if use_bias: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape @@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): kernel.shape, use_bias, quantizer_set, + flatten_axis_k, ) return output, ctx -def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument +def _dense_bwd_rule( + contracting_dims, input_axes, kernel_axes, ctx, grad +): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. - Args: - contracting_dims: Contracting dimensions specification - ctx: Context from forward pass - grad: Gradient from upstream - Returns: Tuple of gradients with respect to inputs """ @@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu kernel_shape, use_bias, quantizer_set, + flatten_axis_k, ) = ctx - casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) + casted_grad, dbias = tex.quantize_dbias( + grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad + ) # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu rowwise_casted_kernel, (g_constracting_dim, k_constracting_dim), ) + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu wgrad = tex.gemm( colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) ) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index a0d1e33e38..a944848881 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -28,6 +28,7 @@ from ..sharding import with_sharding_constraint_by_logical_axes from ..cpp_extensions import is_softmax_kernel_available from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode +from ..sharding import get_non_contracting_logical_axes PRNGKey = Any Shape = Tuple[int, ...] @@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase): :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. + input_axes: Tuple[str, ...], default = None + Indicate the logical axes of sharding constraint to the input, like + (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + sharding constraint. Optimization parameters ----------------------- @@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase): axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = False + input_axes: Tuple[str, ...] = () def __post_init__(self): if self.kernel_init is None: @@ -460,29 +466,35 @@ def __call__(self, inputs: Array) -> Array: axis = _normalize_axes(axis, inputs.ndim) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + + if self.kernel_axes: + assert len(kernel_shape) == len(self.kernel_axes), ( + "Expected len(kernel_shape) to match len(kernel_axes)," + f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}" + ) kernel = nn_partitioning.param_with_axes( "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) + if not QuantizeConfig.is_fp8_enabled(): kernel = kernel.astype(input_dtype) - kernel_compute_shape = ( - reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1), - reduce(operator.mul, features, 1), - ) - kernel = jnp.reshape(kernel, kernel_compute_shape) if self.use_bias: bias = nn_partitioning.param_with_axes( "bias", self.bias_init, features, self.dtype, axes=self.bias_axes - ) - bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) else: bias = None quantizer_set = self.generate_quantizer_set() contract_ind = tuple(range(0, len(axis))) y = dense( - inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set + inputs, + kernel, + contracting_dims=(axis, contract_ind), + input_axes=self.input_axes, + kernel_axes=self.kernel_axes, + quantizer_set=quantizer_set, ) if self.enable_low_rank_adaptation: @@ -491,20 +503,14 @@ def __call__(self, inputs: Array) -> Array: *features[:-1], self.low_rank_adaptation_dim, ) - lora_a_kernel_init_shape = ( - kernel_compute_shape[0], - *features[:-1], - self.low_rank_adaptation_dim, - ) - lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape) lora_a_kernel = nn_partitioning.param_with_axes( "lora_a_kernel", self.kernel_init, - lora_a_kernel_init_shape, + lora_a_kernel_shape, self.dtype, axes=lora_a_kernel_axes, ) - lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) @@ -527,7 +533,6 @@ def __call__(self, inputs: Array) -> Array: y += jnp.reshape(bias, bias_shape) assert y.dtype == input_dtype - y = y.reshape(*inputs.shape[: self.axis], *features) return y @@ -678,6 +683,7 @@ def __call__(self, inputs: Array) -> Array: The output tensors of layer normalization. If :attr:`return_layernorm_output=False`, then this would be None. """ + assert self.axis == -1, "Only support axis = =-1 at this moment" input_dtype = inputs.dtype ln_output = None @@ -692,10 +698,7 @@ def __call__(self, inputs: Array) -> Array: if self.enable_layernorm: inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) - - assert self.axis == -1 # Only support axis = =-1 at this moment features = inputs.shape[-1] - scale, ln_bias = _create_layernorm_parameters( self.layernorm_type, (features,), @@ -731,17 +734,12 @@ def __call__(self, inputs: Array) -> Array: axis = _normalize_axes(axis, y.ndim) - kernel_shape = tuple(y.shape[ax] for ax in axis) + features + kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) if not QuantizeConfig.is_fp8_enabled(): kernel = kernel.astype(input_dtype) - kernel_compute_shape = ( - reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1), - reduce(operator.mul, features, 1), - ) - kernel = jnp.reshape(kernel, kernel_compute_shape) contract_ind = tuple(range(0, len(axis))) @@ -756,11 +754,19 @@ def __call__(self, inputs: Array) -> Array: epsilon=self.epsilon, layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_input_axes, + kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) - z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set) + z = dense( + y, + kernel, + contracting_dims=(axis, contract_ind), + input_axes=self.dot_input_axes, + kernel_axes=self.kernel_axes, + quantizer_set=quantizer_set, + ) if self.enable_low_rank_adaptation: lora_a_kernel_shape = ( @@ -768,20 +774,14 @@ def __call__(self, inputs: Array) -> Array: *features[:-1], self.low_rank_adaptation_dim, ) - lora_a_kernel_init_shape = ( - kernel_compute_shape[0], - *features[:-1], - self.low_rank_adaptation_dim, - ) - lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape) lora_a_kernel = nn_partitioning.param_with_axes( "lora_a_kernel", self.kernel_init, - lora_a_kernel_init_shape, + lora_a_kernel_shape, self.dtype, axes=lora_a_kernel_axes, ) - lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) @@ -803,8 +803,7 @@ def __call__(self, inputs: Array) -> Array: if self.use_bias: bias = nn_partitioning.param_with_axes( "bias", self.bias_init, features, self.dtype, axes=self.bias_axes - ) - bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) if bias is not None: bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape @@ -814,7 +813,7 @@ def __call__(self, inputs: Array) -> Array: z = z / self.depth_scaling assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" - z = z.reshape(*inputs.shape[: self.axis], *features) + # z = z.reshape(*inputs.shape[: self.axis], *features) return z, ln_output # dense_output, layer_norm_output @@ -989,6 +988,8 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: The output tensors of layer normalization. If :attr:`return_layernorm_output=False`, then this would be None. """ + assert self.axis == -1, "Only support axis == -1 at this moment" + ffn1_quantizer_set = self.generate_quantizer_set("_0") ffn2_quantizer_set = self.generate_quantizer_set("_1") @@ -1027,7 +1028,6 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ) # LayerNorm if self.enable_layernorm: - assert self.axis == -1 # Only support axis == -1 at this moment inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) features = inputs.shape[-1] @@ -1071,7 +1071,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations = len(normalized_acts) axis = _canonicalize_tuple(self.axis) axis = _normalize_axes(axis, y.ndim) - kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) + kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim) kernel_1 = nn_partitioning.param_with_axes( "wi_kernel", kernel_1_init, @@ -1081,17 +1081,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, axes=self.kernel_axes_1, ) - kernel_1_compute_shape = ( - reduce(operator.mul, [y.shape[ax] for ax in axis], 1), - num_activations * self.intermediate_dim, - ) - kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape) + if not QuantizeConfig.is_fp8_enabled(): kernel_1 = kernel_1.astype(input_dtype) - if self.kernel_axes_1 is not None: - kernel_1 = with_sharding_constraint_by_logical_axes( - kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:] - ) + hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple @@ -1102,27 +1095,20 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, axes=self.kernel_axes_2, ) - kernel_2_compute_shape = ( - self.intermediate_dim, - reduce(operator.mul, hidden_size_tuple, 1), - ) - kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape) if not QuantizeConfig.is_fp8_enabled(): kernel_2 = kernel_2.astype(input_dtype) - if self.kernel_axes_2 is not None: - kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2) + contract_ind = tuple(range(0, len(axis))) if self.use_bias: - bias_1_shape = num_activations * self.intermediate_dim + bias_1_shape = (num_activations, self.intermediate_dim) bias_1 = nn_partitioning.param_with_axes( "wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1, - ) - bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) bias_2_shape = (hidden_size,) bias_2 = nn_partitioning.param_with_axes( @@ -1131,8 +1117,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_2_shape, self.dtype, axes=self.bias_axes_2, - ) - bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype) + ).astype(input_dtype) else: bias_1 = None bias_2 = None @@ -1141,8 +1126,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn2_ckpt_name = "ffn2" if use_fused_layernorm_mlp: - assert self.axis == -1 # Only support axis = =-1 at this moment - out = layernorm_mlp( y, scale, @@ -1155,6 +1138,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): norm_input_axes=self.layernorm_input_axes, dot_1_input_axes=self.dot_1_input_axes, dot_2_input_axes=self.dot_2_input_axes, + kernel_1_axes=self.kernel_axes_1, + kernel_2_axes=self.kernel_axes_2, ffn1_ckpt_name=ffn1_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name, activation_type=normalized_acts, @@ -1175,6 +1160,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): epsilon=self.epsilon, layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_1_input_axes, + kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, ) else: @@ -1183,35 +1169,31 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): y, kernel_1, contracting_dims=(axis, contract_ind), + input_axes=self.dot_1_input_axes, + kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, ) + dot_1_output_axes = ( + *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis), + *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind), + ) + x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes) if self.enable_low_rank_adaptation: - wi_lora_a_kernel_shape = ( - kernel_1_compute_shape[0], - num_activations, - self.low_rank_adaptation_dim, - ) - wi_lora_a_kernel_init_shape = ( - kernel_1_each_shape[0], - num_activations, - self.low_rank_adaptation_dim, - ) - wi_lora_a_kernel_init_each_shape = ( - kernel_1_each_shape[0], + wi_lora_a_kernel_each_shape = ( + kernel_1_each_shape[: len(axis)], self.low_rank_adaptation_dim, ) - wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape) + wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1) wi_lora_a_kernel = nn_partitioning.param_with_axes( "wi_lora_a_kernel", kernel_1_init, num_activations, - -1, - wi_lora_a_kernel_init_each_shape, + -2, + wi_lora_a_kernel_each_shape, self.dtype, axes=wi_lora_a_kernel_axes, ) - wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_b_kernel_shape = ( @@ -1232,7 +1214,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): x += _apply_low_rank_adaptation( y, axis, - num_activations * self.intermediate_dim, + (num_activations, self.intermediate_dim), wi_lora_a_kernel, wi_lora_b_kernel, self.low_rank_adaptation_alpha, @@ -1246,11 +1228,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): z = activation(x, normalized_acts) else: activations = [] - x = jnp.split(x, num_activations, axis=-1) + x = jnp.split(x, num_activations, axis=-2) for idx, act_fn in enumerate(normalized_acts): x_i = _convert_to_activation_function(act_fn)(x[idx]) activations.append(x_i) z = reduce(operator.mul, activations) + z = jnp.squeeze(z, axis=-2) z = z.astype(input_dtype) z = nn.Dropout( @@ -1264,7 +1247,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): # DenseGeneral 2 out = dense( - z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set + z, + kernel_2, + contracting_dims=(axis, contract_ind), + input_axes=self.dot_2_input_axes, + kernel_axes=self.kernel_axes_2, + quantizer_set=ffn2_quantizer_set, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 3fe32401bd..727ff78c2d 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -33,10 +33,9 @@ def layernorm_dense( norm_type: str = "layernorm", zero_centered_gamma: bool = False, epsilon: float = 1e-6, - # The logic axes of sharding constraint to the layernorm input. layernorm_input_axes: Tuple[str, ...] = None, - # The logic axes of sharding constraint to the dot input. dot_input_axes: Tuple[str, ...] = None, + kernel_axes: Tuple[str, ...] = None, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -56,6 +55,7 @@ def layernorm_dense( epsilon: Small constant for numerical stability in normalization layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input + kernel_axes: Logical axes for sharding the weight matrix quantizer_set: Set of quantizers for different tensor types Returns: @@ -78,6 +78,7 @@ def layernorm_dense( epsilon, layernorm_input_axes, dot_input_axes, + kernel_axes, quantizer_set, ) return output @@ -91,6 +92,7 @@ def layernorm_dense( 7, 8, 9, + 10, ), ) def _layernorm_dense( @@ -104,6 +106,7 @@ def _layernorm_dense( epsilon: float, layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], + kernel_axes: Tuple[str, ...], quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -139,6 +142,7 @@ def _layernorm_dense( epsilon, layernorm_input_axes, dot_input_axes, + kernel_axes, quantizer_set, ) return output @@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule( epsilon, layernorm_input_axes, dot_input_axes, + kernel_axes, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule( x_contracting_dims = (len(x.shape) - 1,) k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] - assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) @@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule( norm_type, quantizer_set.x, ) + casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) - casted_kernel = tex.quantize(kernel, quantizer_set.kernel) - - casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) + flatten_axis = 1 - len(kernel.shape) + casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) + casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out...) @@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule( k_contracting_dims, use_bias, quantizer_set, + flatten_axis, ) return output, ctx @@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule( epsilon, layernorm_input_axes, dot_input_axes, # pylint: disable=unused-argument + kernel_axes, ctx, grad, ): @@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule( k_contracting_dims_in_fwd, use_bias, quantizer_set, + flatten_axis, ) = ctx - grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes) - - casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) + casted_grad, dbias = tex.quantize_dbias( + grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad + ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim g_constracting_dim = tuple( @@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule( (x_constracting_dim, g_constracting_dim), ) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) + dx, dgamma, dbeta = tex.normalization_bwd( dgrad, x, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index f6caad62e3..e7e3fd2fb9 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -23,6 +23,7 @@ from . import cpp_extensions as tex from .layernorm import canonicalize_norm_type from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set +from .sharding import get_non_contracting_logical_axes def layernorm_mlp( @@ -37,6 +38,8 @@ def layernorm_mlp( norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None, + kernel_1_axes: Tuple[str, ...] = None, + kernel_2_axes: Tuple[str, ...] = None, ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), @@ -66,6 +69,8 @@ def layernorm_mlp( norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication + kernel_1_axes: Logical axes for sharding the first weight matrix + kernel_2_axes: Logical axes for sharding the second weight matrix ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation @@ -109,6 +114,8 @@ def layernorm_mlp( norm_input_axes, dot_1_input_axes, dot_2_input_axes, + kernel_1_axes, + kernel_2_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type, @@ -117,7 +124,7 @@ def layernorm_mlp( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -132,6 +139,8 @@ def _layernorm_mlp( norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], + kernel_1_axes: Tuple[str, ...], + kernel_2_axes: Tuple[str, ...], ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], @@ -179,6 +188,8 @@ def _layernorm_mlp( norm_input_axes, dot_1_input_axes, dot_2_input_axes, + kernel_1_axes, + kernel_2_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type, @@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule( norm_input_axes, dot_1_input_axes, dot_2_input_axes, + kernel_1_axes, + kernel_2_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type, @@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ + del kernel_2_axes + ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets # x should be in shape of (batch..., hidden) - # Kernel_1 should be in shape of (hidden_in, activation_len * intermediate) + # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) # Kernel_2 should be in shape of (intermediate, hidden_in) - assert len(kernel_1.shape) == 2 + assert len(kernel_1.shape) == 3 assert len(kernel_2.shape) == 2 - assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type) + assert kernel_1.shape[-2] == len(activation_type) x_contracting_dims = (len(x.shape) - 1,) k_contracting_dims = (0,) assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] - assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0] use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule( norm_type, quantizer=ffn1_quantizer_set.x, ) - - casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel) - casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) + casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) + # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out) dot_1_output = tex.gemm( @@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule( casted_kernel_1.get_colwise_tensor(), (x_contracting_dims, k_contracting_dims), ) + + dot_1_output_axes = ( + *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), + *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), + ) + dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) + if use_bias_1: bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape @@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule( (x_contracting_dims, k_contracting_dims), ) + dot_2_output_axes = ( + *get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims), + *get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims), + ) + dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes) + if use_bias_2: bias_2_shape = bias_2.shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape @@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule( norm_input_axes, dot_1_input_axes, dot_2_input_axes, - ffn1_ckpt_name, # pylint: disable=unused-argument - ffn2_ckpt_name, # pylint: disable=unused-argument + kernel_1_axes, + kernel_2_axes, + ffn1_ckpt_name, + ffn2_ckpt_name, activation_type, ctx, grad, @@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, @@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule( ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - g_constracting_dim_2 = tuple( + g_contracting_dims_2 = tuple( range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) ) # k_non_contracting_dims - k_constracting_dim_2 = tuple( + k_contracting_dims_2 = tuple( dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd ) @@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule( dgrad_2 = tex.gemm( casted_grad.get_rowwise_tensor(), rowwise_casted_kernel_2, - (g_constracting_dim_2, k_constracting_dim_2), + (g_contracting_dims_2, k_contracting_dims_2), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - x_constracting_dim = g_constracting_dim = tuple( + x_contracting_dims = g_contracting_dims = tuple( range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) ) @@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( colwise_casted_act_out, casted_grad.get_colwise_tensor(), - (x_constracting_dim, g_constracting_dim), + (x_contracting_dims, g_contracting_dims), ) + wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( dgrad_2, @@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule( ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - g_constracting_dim_1 = tuple( - range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim) + dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim + g_contracting_dims_1 = tuple( + range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) ) # k_non_contracting_dims - k_constracting_dim_1 = tuple( + k_contracting_dims_1 = tuple( dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd ) @@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule( dgrad_1 = tex.gemm( casted_dact_out.get_rowwise_tensor(), rowwise_casted_kernel_1, - (g_constracting_dim_1, k_constracting_dim_1), + (g_contracting_dims_1, k_contracting_dims_1), ) - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes) + dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) # TN GEMM # (hidden, batch...) x (hidden, batch...) wgrad_1 = tex.gemm( colwise_casted_ln_out, casted_dact_out.get_colwise_tensor(), - (x_constracting_dim, g_constracting_dim), + (x_contracting_dims, g_contracting_dims), ) + wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) + dx, dgamma, dbeta = tex.normalization_bwd( dgrad_1, x, diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index cdbe764ab2..b1e9ba03b4 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -57,18 +57,27 @@ def _dq_func_block_scaling(scaled_tensor): data = scaled_tensor.data.astype(jnp.float32) data_shape = data.shape scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32) + flatten_axis = scaled_tensor.flatten_axis + flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" scale_shape = scaled_tensor.scaling_mode.get_scale_shape( - scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False + data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis ) scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding + data = data.reshape( - *data_shape[:-2], - scale_shape[-2], - int(data_shape[-2] / scale_shape[-2]), + *data_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *data_shape[flatten_axis:-1], scale_shape[-1], int(data_shape[-1] / scale_shape[-1]), ) - scale = jnp.expand_dims(scale, axis=(-1, -3)) + + # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. + scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1)) # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape( data_shape diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 629e3f5bc2..bd7045453b 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -14,7 +14,7 @@ import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeAxis +from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory @@ -24,7 +24,7 @@ ) __all__ = [ - "QuantizeAxis", + "QuantizeLayout", "Quantizer", "QuantizerSet", "DelayedScaleQuantizer", @@ -45,12 +45,12 @@ class Quantizer(ABC): Attributes: q_dtype: The data type for quantized values scaling_mode: The scaling mode to use for quantization - q_axis: The quantization axis (row-wise, column-wise, or both) + q_layout: The quantization axis (row-wise, column-wise, or both) """ q_dtype: jnp.dtype scaling_mode: ScalingMode - q_axis: QuantizeAxis + q_layout: QuantizeLayout def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -59,7 +59,7 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = () - aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout) return (children, aux_data) @classmethod @@ -85,30 +85,31 @@ def is_2x2x(self) -> bool: Returns: True if using both row-wise and column-wise quantization """ - return self.q_axis == QuantizeAxis.ROWWISE_COLWISE + return self.q_layout == QuantizeLayout.ROWWISE_COLWISE @abstractmethod - def get_layout(self) -> str: - """Get the data layout. + def get_data_layout(self) -> str: + """Get the data data_layout. Returns: - Data layout in string format + Data data_layout in string format """ @abstractmethod - def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: + def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: """Core quantization function to be implemented by subclasses. Args: x: Input tensor to quantize is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values, default is x.dtype + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x containing the quantized data """ - def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None): + def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1): """Quantize a tensor using the internal _quantize_func(). Args: @@ -116,21 +117,26 @@ def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None): is_rowwise: Whether to use row-wise quantization is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ if (is_rowwise and is_colwise) or self.is_2x2x(): - rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) - colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) + rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) + colwise_tensor = self._quantize_func( + x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis + ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) if is_colwise: - return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) + return self._quantize_func( + x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis + ) - return self._quantize_func(x, dq_dtype=dq_dtype) + return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) - def get_scale_shapes(self, data_shape, is_padded=True): + def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1): """Get shapes for scale tensors. Args: @@ -140,7 +146,7 @@ def get_scale_shapes(self, data_shape, is_padded=True): Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ - return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded) + return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis) def get_scale_dtype(self): """Get the data type for scale tensors. @@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer): Attributes: scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING - q_axis: Quantization axis (default: ROWWISE_COLWISE) + q_layout: Quantization axis (default: ROWWISE_COLWISE) scale: Current scaling factor amax_history: History of maximum absolute values """ scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING - q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) amax_history: jnp.ndarray = field( @@ -181,35 +187,37 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.scale, self.amax_history) - aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout) return (children, aux_data) - def get_layout(self) -> str: - """Get the data layout string. + def get_data_layout(self) -> str: + """Get the data data_layout string. Returns: - Data layout in string format + Data data_layout in string format Raises: ValueError: If quantization axis is invalid """ - layout = "NT" - if self.q_axis == QuantizeAxis.ROWWISE_COLWISE: - return layout - if self.q_axis == QuantizeAxis.ROWWISE: - return layout[0] - if self.q_axis == QuantizeAxis.COLWISE: - return layout[1] - raise ValueError(f"Invalid q_axis: {self.q_axis}") - - def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: + data_layout = "NT" + if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: + return data_layout + if self.q_layout == QuantizeLayout.ROWWISE: + return data_layout[0] + if self.q_layout == QuantizeLayout.COLWISE: + return data_layout[1] + raise ValueError(f"Invalid q_layout: {self.q_layout}") + + def _quantize_func( + self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 + ) -> ScaledTensor1x: """Quantize function helper for delayed scaling FP8. Args: x: Input tensor to quantize is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values - + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x containing the quantized data """ @@ -232,9 +240,12 @@ def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> Sca scale_inv=scale_inv, scaling_mode=self.scaling_mode, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) - def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None): + def quantize( + self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1 + ): """Quantize a tensor using the internal _quantize_func(). Args: @@ -242,32 +253,40 @@ def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype is_rowwise: Whether to use row-wise quantization is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if flatten_axis < 0: + flatten_axis += x.ndim + assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" + is_rowwise = ( is_rowwise if is_rowwise is not None - else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x()) + else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) ) is_colwise = ( is_colwise if is_colwise is not None - else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x()) + else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) ) - rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) + rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = None if is_colwise: colwise_tensor = ScaledTensorFactory.create_1x( - data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))), + data=jnp.transpose( + rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis)) + ), scale_inv=rowwise_tensor.scale_inv, scaling_mode=self.scaling_mode, dq_dtype=dq_dtype, is_colwise=True, - layout="T", + data_layout="T", + flatten_axis=flatten_axis, ) if is_colwise and is_rowwise: return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer): Attributes: scaling_mode: Set to NVTE_MXFP8_1D_SCALING - q_axis: Quantization axis (default: ROWWISE_COLWISE) + q_layout: Quantization axis (default: ROWWISE_COLWISE) """ scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING - q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE - def get_layout(self) -> str: - """Get the data layout string. + def get_data_layout(self) -> str: + """Get the data data_layout string. Returns: - Data layout in string format + Data data_layout in string format """ if self.is_2x2x(): return "NN" return "N" - def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: + def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: """Quantize function helper for block scaling FP8. Args: x: Input tensor to quantize is_colwise: Whether to use column-wise quantization dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x containing the quantized data """ # TODO(Phuong): use quantize_func from JAX + if flatten_axis < 0: + flatten_axis = x.ndim + flatten_axis + assert ( + 0 <= flatten_axis < x.ndim + ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}" + dq_dtype = dq_dtype if dq_dtype is not None else x.dtype x_shape = x.shape - scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False) + scale_shape = self.scaling_mode.get_scale_shape( + x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis + ) scale_dtype = self.scaling_mode.get_scale_dtype() x = x.reshape( - *x_shape[:-2], - scale_shape[-2], - int(x_shape[-2] / scale_shape[-2]), + *x_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *x_shape[flatten_axis:-1], scale_shape[-1], int(x_shape[-1] / scale_shape[-1]), ) - amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True) + amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True) MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) scales = amax.astype(jnp.float32) / MAX @@ -409,6 +438,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: self.scaling_mode, is_colwise=is_colwise, dq_dtype=dq_dtype, + flatten_axis=flatten_axis, ) def _cast_to_e8m0_with_rounding_up(self, scales): @@ -509,7 +539,7 @@ def create( n_quantizers: int = 1, scaling_mode: ScalingMode = None, q_dtype: jnp.dtype = None, - q_axis: QuantizeAxis = None, + q_layout: QuantizeLayout = None, **kwargs, ) -> Quantizer: """Create one or more quantizers with specified parameters. @@ -518,7 +548,8 @@ def create( n_quantizers: Number of quantizers to create scaling_mode: Scaling mode to use q_dtype: Quantization data type - q_axis: Quantization axis + q_layout: Quantization axis + flatten_axis: The quantization axis for the tensor **kwargs: Additional arguments for quantizer initialization Returns: @@ -534,7 +565,7 @@ def create( quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) quantizers.append( quantizer_type( - q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs + q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs ) ) return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) @@ -554,11 +585,11 @@ def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> Quanti A QuantizerSet instance """ if is_2x2x: - q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE + q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE else: - q_axis_x = QuantizeAxis.ROWWISE - q_axis_kernel = QuantizeAxis.COLWISE - q_axis_dgrad = None + q_layout_x = QuantizeLayout.ROWWISE + q_layout_kernel = QuantizeLayout.COLWISE + q_layout_dgrad = None if "quantize_meta_set" in kwargs: quantize_meta_set = kwargs.get("quantize_meta_set") @@ -577,9 +608,11 @@ def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> Quanti else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x) - q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel) - q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad) + q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x) + q_kernel = QuantizerFactory.create( + 1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel + ) + q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @staticmethod diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 805c034334..a9c93a3553 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -40,7 +40,11 @@ def get_scale_dtype(self) -> jnp.dtype: @abstractmethod def get_scale_shape( - self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, ) -> Tuple[int, ...]: """Get the shape for scale tensors. @@ -48,7 +52,7 @@ def get_scale_shape( data_shape: The shape of the tensor being quantized is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape - + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors """ @@ -69,7 +73,11 @@ def get_scale_dtype(self) -> jnp.dtype: return jnp.float32 def get_scale_shape( - self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, ) -> Tuple[int, ...]: """Get the shape for scale tensors in delayed scaling. @@ -77,6 +85,7 @@ def get_scale_shape( data_shape: The shape of the tensor being scaled is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors - (1,) @@ -113,8 +122,35 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float8_e8m0fnu + def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim): + """Remove excess padding from the scale shape and return the shape with respect to the original data shape.""" + if len(data_shape) > 1: + # handle last dim + assert data_shape[-1] % scale_block_dim == 0 + last = data_shape[-1] // scale_block_dim + scale_shape = (last,) + assert n_scale_blocks % last == 0 + n_scale_blocks //= last + # handle middle dim, exclude first and last + for mid in reversed(data_shape[1:-1]): + scale_shape = (mid,) + scale_shape + assert n_scale_blocks % mid == 0 + n_scale_blocks //= mid + scale_shape = (n_scale_blocks,) + scale_shape + else: + scale_shape = (n_scale_blocks,) + + assert len(scale_shape) == len( + data_shape + ), f"scale_shape {scale_shape}, data_shape {data_shape}" + return scale_shape + def get_scale_shape( - self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, ) -> Tuple[int, ...]: """Get the shape for scale tensors in block scaling. @@ -122,6 +158,7 @@ def get_scale_shape( data_shape: The shape of the tensor being quantized is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors @@ -135,35 +172,48 @@ def get_scale_shape( block_x, block_y = self._block_dims alignment_x, alignment_y = block_alignment - seq_axis = len(data_shape) - 2 - + if flatten_axis < 0: + flatten_axis = len(data_shape) + flatten_axis assert ( - data_shape[seq_axis] % block_x == 0 - ), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}" + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + + assert data_shape[flatten_axis - 1] % block_x == 0, ( + f"Data shape {data_shape} should be divisible by block_x {block_x} in axis" + f" {flatten_axis - 1}" + ) assert ( data_shape[-1] % block_y == 0 - ), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1" + ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1" - # NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1 - n_block_seq = data_shape[seq_axis] // block_x - n_block_y = data_shape[-1] // block_y + flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1) + flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1) - n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq + assert flattened_first_dim % block_x == 0, ( + f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape" + f" {data_shape} - should be divisible by block_x {block_x}" + ) + assert flattened_last_dim % block_y == 0, ( + "Flattened last dim - mutiplication of" + f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be" + f" divisible by block_y {block_y}" + ) - # Padding - n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x - n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y + n_block_x = int(flattened_first_dim / block_x) + n_block_y = int(flattened_last_dim / block_y) - out_shape = () - for i in range(seq_axis): - d = data_shape[i] - out_shape += (d,) - assert n_flat_first_dim % d == 0 - n_flat_first_dim //= d + # padding + n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x) + n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y) - out_shape += (n_flat_first_dim, n_block_y) + first_dim_scale_shape = self._apply_scale_shape_correction( + data_shape[:flatten_axis], n_block_x, block_x + ) + last_dim_scale_shape = self._apply_scale_shape_correction( + data_shape[flatten_axis:], n_block_y, block_y + ) - return out_shape + return (*first_dim_scale_shape, *last_dim_scale_shape) # (Phuong: Map the NVTEScalingMode value to the ScalingMode @@ -208,34 +258,40 @@ def get_scale_dtype(self): """ return self._get_impl().get_scale_dtype() - def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]: + def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. Args: data_shape: Shape of the data tensor is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ rowwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=False, is_padded=is_padded + data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis + ) + colwise_scale_shape = self.get_scale_shape( + data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis ) - colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded) return (rowwise_scale_shape, colwise_scale_shape) - def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]: + def get_scale_shape( + self, data_shape, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: """Get the shape for scale tensors in this mode. Args: data_shape: Shape of the data tensor is_colwise: Whether to use column-wise scaling is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. Returns: The shape for scale tensors """ - return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded) + return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) def __eq__(self, other): """Compare this scaling mode with another. diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 8c01dd9af0..c34a235d94 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -15,7 +15,7 @@ import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeAxis +from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode from .dequantizer import Dequantizer @@ -84,6 +84,17 @@ def get_colwise_tensor(self): ValueError: If called on a tensor that doesn't support column-wise access """ + @abstractmethod + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + @register_pytree_node_class @dataclass @@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor): dq_dtype: The data type for dequantized values _dq_func: The dequantization function is_colwise: Whether the tensor uses column-wise quantization - layout: The layout specification for the tensor + data_layout: The data_layout specification for the tensor + flatten_axis: The quantization axis for the tensor """ data: jnp.ndarray @@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor): dq_dtype: jnp.dtype _dq_func: Callable is_colwise: bool - layout: str + data_layout: str + flatten_axis: int = -1 def __post_init__(self): """Validates and adjusts the scale_inv shape after initialization. @@ -117,11 +130,22 @@ def __post_init__(self): Ensures the scale_inv shape matches the expected shape based on the scaling mode and quantization direction. Pads the scale_inv if necessary. """ + flatten_axis = ( + len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis + ) + assert ( + 0 < flatten_axis < len(self.data.shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}" + + if self.data_layout == "T": + flatten_axis = self.data.ndim - flatten_axis + self.flatten_axis = flatten_axis + expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True + self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis ) expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False + self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis ) if self.scale_inv.shape != expected_scale_shape: assert self.scale_inv.shape == expected_unpadded_scale_shape, ( @@ -144,7 +168,14 @@ def tree_flatten(self): A tuple containing (children, aux_data) for tree operations """ children = (self.data, self.scale_inv) - aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout) + aux_data = ( + self.scaling_mode, + self.dq_dtype, + self._dq_func, + self.is_colwise, + self.data_layout, + self.flatten_axis, + ) return (children, aux_data) def dequantize(self): @@ -183,6 +214,46 @@ def get_colwise_tensor(self): raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!") + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + # axis_names were given for N layout, so needs to be transpose for T layout + if self.data_layout == "T": + assert self.flatten_axis > 0 + flatten_axis = -self.flatten_axis + axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis]) + else: + axis_names = logical_axis_names + + data = with_sharding_constraint_by_logical_axes(self.data, axis_names) + + if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + # TODO(Phuong): Handle padding !? + scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) + else: + scale_inv = self.scale_inv + + # TODO(Phuong): constaint padded scale_inv? + return ScaledTensor1x( + data=data, + scale_inv=scale_inv, + scaling_mode=self.scaling_mode, + dq_dtype=self.dq_dtype, + _dq_func=self._dq_func, + is_colwise=self.is_colwise, + data_layout=self.data_layout, + flatten_axis=self.flatten_axis, + ) + @register_pytree_node_class @dataclass @@ -233,6 +304,27 @@ def get_colwise_tensor(self): """ return self.colwise_tensor + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes( + logical_axis_names + ) + colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes( + logical_axis_names + ) + + return ScaledTensor2x(rowwise_tensor, colwise_tensor) + @dataclass class ScaledTensorFactory: @@ -244,7 +336,13 @@ class ScaledTensorFactory: @staticmethod def create_1x( - data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N" + data, + scale_inv, + scaling_mode, + dq_dtype=jnp.bfloat16, + is_colwise=False, + data_layout="N", + flatten_axis=-1, ): """Creates a single-scale quantized tensor. @@ -254,13 +352,16 @@ def create_1x( scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) is_colwise: Whether to use column-wise quantization (default: False) - layout: The layout specification (default: "N") + data_layout: The data_layout specification (default: "N") + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor1x instance """ dq_func = Dequantizer.funcs.get(scaling_mode) - return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout) + return ScaledTensor1x( + data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis + ) @staticmethod def create_2x( @@ -270,7 +371,8 @@ def create_2x( colwise_scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, - layout="NN", + data_layout="NN", + flatten_axis=-1, ): """Creates a double-scale quantized tensor. @@ -281,7 +383,8 @@ def create_2x( colwise_scale_inv: The column-wise inverse scaling factors scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) - layout: The layout specification (default: "NN") + data_layout: The data_layout specification (default: "NN") + flatten_axis: The quantization axis for the tensor Returns: A ScaledTensor2x instance @@ -294,7 +397,8 @@ def create_2x( dq_dtype, dq_func, is_colwise=False, - layout=layout[0], + data_layout=data_layout[0], + flatten_axis=flatten_axis, ) colwise_tensor = ScaledTensor1x( colwise_data, @@ -303,7 +407,8 @@ def create_2x( dq_dtype, dq_func, is_colwise=True, - layout=layout[1], + data_layout=data_layout[1], + flatten_axis=flatten_axis, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -315,8 +420,9 @@ def create( colwise_scale_inv: jnp.ndarray, scaling_mode: ScalingMode, dq_dtype: jnp.dtype = jnp.bfloat16, - layout: str = "NN", - q_axis: QuantizeAxis = QuantizeAxis.ROWWISE, + data_layout: str = "NN", + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, + flatten_axis: int = -1, ): """Creates a scaled tensor based on the quantization axis. @@ -327,13 +433,13 @@ def create( colwise_scale_inv: The column-wise inverse scaling factors scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) - layout: The layout specification (default: "NN") - q_axis: The quantization axis (default: ROWWISE) + data_layout: The data_layout specification (default: "NN") + q_layout: The quantization axis (default: ROWWISE) Returns: - Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis + Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout """ - if q_axis == QuantizeAxis.ROWWISE_COLWISE: + if q_layout == QuantizeLayout.ROWWISE_COLWISE: return ScaledTensorFactory.create_2x( data, scale_inv, @@ -341,12 +447,19 @@ def create( colwise_scale_inv, scaling_mode, dq_dtype, - layout=layout, + data_layout=data_layout, + flatten_axis=flatten_axis, ) - is_colwise = q_axis == QuantizeAxis.COLWISE + is_colwise = q_layout == QuantizeLayout.COLWISE return ScaledTensorFactory.create_1x( - data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0] + data, + scale_inv, + scaling_mode, + dq_dtype, + is_colwise=is_colwise, + data_layout=data_layout[0], + flatten_axis=flatten_axis, ) @@ -360,24 +473,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . Returns: The tensor with applied sharding constraints """ - if isinstance(x, ScaledTensor1x): - return ScaledTensor1x( - data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names), - scale_inv=x.scale_inv, - scaling_mode=x.scaling_mode, - dq_dtype=x.dq_dtype, - _dq_func=x._dq_func, - is_colwise=x.is_colwise, - layout=x.layout, - ) - if isinstance(x, ScaledTensor2x): - return ScaledTensor2x( - rowwise_tensor=with_sharding_constraint_by_logical_axes( - x.rowwise_tensor, logical_axis_names - ), - colwise_tensor=with_sharding_constraint_by_logical_axes( - x.colwise_tensor, logical_axis_names - ), - ) + if isinstance(x, ScaledTensor): + return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 8e7ce93986..df3f38cbd1 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names): Convert logical axes to PartitionSpec """ rules = get_sharding_map_logic_axis_to_mesh_axis() - mesh_axis_names = [rules[name] for name in logical_axis_names] + # mesh_axis_names = [rules[name] for name in logical_axis_names] + mesh_axis_names = [] + for name in logical_axis_names: + axis_name = rules[name] if name in rules else None + mesh_axis_names.append(axis_name) pspec = jax.sharding.PartitionSpec(*mesh_axis_names) return pspec @@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: t """ A wrapper function to jax.lax.with_sharding_constraint to accept logical axes. """ - if logical_axis_names is None: + if not logical_axis_names: return x assert len(x.shape) == len(logical_axis_names) @@ -315,3 +319,25 @@ class ShardingType(Enum): TP_ROW = (MajorShardingType.TP, "tp_row") DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col") DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row") + + +def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims): + """Get logical axes for non-contracting dimensions. + + Args: + ndim: Number of dimensions in the tensor. + logical_axes: Tuple of logical axes for each dimension. + contracting_dims: Set of dimensions that are being contracted. + + Returns: + Tuple of logical axes for non-contracting dimensions. + """ + if not logical_axes: + logical_axes = (None,) * ndim + elif len(logical_axes) < ndim: + logical_axes = logical_axes + (None,) * (ndim - len(logical_axes)) + assert len(logical_axes) == ndim + + non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims] + non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims) + return non_contracting_logical_axes From efdf8e0d963b61c15eba7b29347f02dab2b33488 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 14:21:13 -0700 Subject: [PATCH 19/29] Update test gemm supported predicate. Signed-off-by: Keith Wyss --- tests/pytorch/test_float8_blockwise_gemm_exact.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 94014d36b5..9ddb4b9989 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -12,12 +12,17 @@ Float8BlockQuantizer, Float8BlockwiseQTensor, ) +from transformer_engine.pytorch.utils import get_device_compute_capability from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - return float(torch.version.cuda) >= 12.9 + return ( + get_device_compute_capability() >= (9, 0) + and get_device_compute_capability() < (10, 0) + and float(torch.version.cuda) >= 12.9 + ) def cublas_gemm_fp8_blockwise_case( From a9f209acca6ebb9441a8c04b4e834da9f6d0ead2 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 17:38:37 -0700 Subject: [PATCH 20/29] Use sgemm like interfaces and naming. Signed-off-by: Keith Wyss --- .../common/gemm/cublaslt_gemm.cu | 63 ++++++++++--------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ed13fccaef..aa88eb9bc4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -82,7 +82,7 @@ struct GemmParam { GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - int A0, int A1, int B0, int B1) { + int m, int k, int n) { using namespace transformer_engine; NVTE_CHECK( A.scaling_mode == B.scaling_mode || @@ -97,8 +97,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool transb_bool = transB == CUBLAS_OP_T; int arch = cuda::sm_arch(cuda::current_device()); - int a_major_dim; - int b_major_dim; + int a_storage_outer_dim; + int b_storage_outer_dim; if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { // For this scaling mode, the quantizer stores // rowwise data and transposes the data for columnwise @@ -106,10 +106,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // and the transA and transB values to pass to cublas // should always be TN. - a_major_dim = transa_bool ? A0 : A1; - b_major_dim = transb_bool ? B1 : B0; - ret.lda = transa_bool ? A1 : A0; - ret.ldb = transb_bool ? B0 : B1; + a_storage_outer_dim = m; + b_storage_outer_dim = n; + ret.lda = k; + ret.ldb = k; ret.transA = CUBLAS_OP_T; ret.transB = CUBLAS_OP_N; @@ -122,10 +122,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // transB, which are passed along to cuBLAS. // NOTE: There is some logic below that may edit this // decision for A and B depending on dtype and arch. - a_major_dim = A0; - b_major_dim = B0; - ret.lda = A1; - ret.ldb = B1; + a_storage_outer_dim = transa_bool ? m : k; + b_storage_outer_dim = transb_bool ? k : n; + ret.lda = transa_bool ? k : m; + ret.ldb = transb_bool ? n : k; if (transa_bool && transb_bool) { // TT NVTE_ERROR("TT layout not allowed."); @@ -146,8 +146,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A = A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; ret.A_scale_inv = A.columnwise_scale_inv.dptr; - a_major_dim = A1; - ret.lda = A0; + a_storage_outer_dim = m; + ret.lda = k; } } } @@ -162,8 +162,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B = B.columnwise_data.dptr; ret.transB = CUBLAS_OP_N; ret.B_scale_inv = B.columnwise_scale_inv.dptr; - b_major_dim = B1; - ret.ldb = B0; + b_storage_outer_dim = n; + ret.ldb = k; } } } else { @@ -193,12 +193,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. - NVTE_CHECK((a_major_dim % 8) == 0, + NVTE_CHECK((a_storage_outer_dim % 8) == 0, "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Observed this requirement only present for B tensor is 1D quantized. if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { NVTE_CHECK( - (b_major_dim % 8) == 0, + (b_storage_outer_dim % 8) == 0, "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); } NVTE_CHECK((ret.ldb % 16) == 0, @@ -221,14 +221,11 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int A0, int A1, int B0, int B1, + const Tensor *inputBias, Tensor *outputPreGelu, int m, int k, int n, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { - const int m = transa == CUBLAS_OP_T ? A0 : A1; - const int k = transa == CUBLAS_OP_T ? A1 : A0; - const int n = transb == CUBLAS_OP_T ? B1 : B0; const int ldd = m; // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { @@ -236,7 +233,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, A0, A1, B0, B1); + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, k, n); void *C = outputD->data.dptr; void *D = outputD->data.dptr; @@ -359,11 +356,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } -#if CUDA_VERSION >= 12090 } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { +#if CUDA_VERSION >= 12090 float *A_scale_inverse = reinterpret_cast(param.A_scale_inv); float *B_scale_inverse = reinterpret_cast(param.B_scale_inv); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -381,8 +378,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; -#endif -#endif +#else + NVTE_ERROR("FP8 block scaling requires CUDA 12.9+"); +#endif // CUDA_VERSION >= 12090 +#endif // CUDA_VERSION >= 12080 } else { NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + to_string(inputB->scaling_mode) + "."); @@ -581,7 +580,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons const size_t B0 = inputB->flat_first_dim(); const size_t B1 = inputB->flat_last_dim(); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, A0, A1, B0, B1, + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); @@ -612,8 +615,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, inputA->data.shape[0], - inputA->data.shape[1], inputB->data.shape[0], inputB->data.shape[1], + + const int m = transa == CUBLAS_OP_T ? inputA->data.shape[0] : inputA->data.shape[1]; + const int k = transa == CUBLAS_OP_T ? inputA->data.shape[1] : inputA->data.shape[0]; + const int n = transb == CUBLAS_OP_T ? inputB->data.shape[0] : inputB->data.shape[1]; + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); From 861c8700175ab97cce0c0d5b56793b41f3c502d9 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 17:58:14 -0700 Subject: [PATCH 21/29] Rewrite GEMM comment. Signed-off-by: Keith Wyss --- transformer_engine/common/gemm/cublaslt_gemm.cu | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index aa88eb9bc4..6824eea7a1 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -100,11 +100,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla int a_storage_outer_dim; int b_storage_outer_dim; if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { - // For this scaling mode, the quantizer stores - // rowwise data and transposes the data for columnwise - // data so the physical layout is always row major - // and the transA and transB values to pass to cublas - // should always be TN. + // For this scaling mode, a quantized tensor of the data is stored + // in a row major layout for rowwise data and a quantized tensor of + // the transpose of the data is also stored in row major layout. + // + // cublas will be called with "TN", but Transformer engine uses + // the "TN" parameters to choose between rowwise and columnwise + // row major tensors. a_storage_outer_dim = m; b_storage_outer_dim = n; From ada643897c8b9bbbb64ca46d34de5f1d55390ada Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Fri, 4 Apr 2025 18:27:52 -0700 Subject: [PATCH 22/29] MR Feedback. Signed-off-by: Keith Wyss --- transformer_engine/common/gemm/cublaslt_gemm.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 6824eea7a1..483a1380ef 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -582,9 +582,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons const size_t B0 = inputB->flat_first_dim(); const size_t B1 = inputB->flat_last_dim(); - const int m = transa == CUBLAS_OP_T ? A0 : A1; - const int k = transa == CUBLAS_OP_T ? A1 : A0; - const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int m = transa ? A0 : A1; + const int k = transa ? A1 : A0; + const int n = transb ? B1 : B0; cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, @@ -618,9 +618,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - const int m = transa == CUBLAS_OP_T ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa == CUBLAS_OP_T ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb == CUBLAS_OP_T ? inputB->data.shape[0] : inputB->data.shape[1]; + const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; + const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; + const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, From e484269c140010ed5830275f392b53649ec66842 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sun, 6 Apr 2025 01:11:33 +0000 Subject: [PATCH 23/29] Refactor GEMM param canonicalization Configure A and B matrices separately. Have separate code path for each scaling mode. Signed-off-by: Tim Moon --- .../common/gemm/cublaslt_gemm.cu | 300 +++++++++--------- 1 file changed, 153 insertions(+), 147 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 483a1380ef..ed691aa532 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -52,37 +52,36 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); } +/* Parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ struct GemmParam { - void *A; - void *B; - // The layout (e.g. TN to call cublas with) - cublasOperation_t transA; - cublasOperation_t transB; - transformer_engine::DType Atype; - transformer_engine::DType Btype; - void *A_scale_inv; - void *B_scale_inv; - // ld are leading dimensions or minor dimensions - // in storage - int lda; - int ldb; - - GemmParam(cublasOperation_t transA, cublasOperation_t transB) - : A(nullptr), - B(nullptr), - transA(transA), - transB(transB), - Atype(transformer_engine::DType::kNumTypes), - Btype(transformer_engine::DType::kNumTypes), - A_scale_inv(nullptr), - B_scale_inv(nullptr), - lda(0), - ldb(0) {} + void *A = nullptr; + void *B = nullptr; + cublasOperation_t transA = CUBLAS_OP_N; + cublasOperation_t transB = CUBLAS_OP_N; + transformer_engine::DType Atype = transformer_engine::DType::kNumTypes; + transformer_engine::DType Btype = transformer_engine::DType::kNumTypes; + void *A_scale_inv = nullptr; + void *B_scale_inv = nullptr; + int lda = 0; // A column strides + int ldb = 0; // B column strides }; +/* Populate parameters for cuBLAS GEMM + * + * cuBLAS follows the BLAS convention of column-major ordering. This + * is different than the row-major that is typically used in + * Transformer Engine. + * + */ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - int m, int k, int n) { + int m, int n, int k) { using namespace transformer_engine; NVTE_CHECK( A.scaling_mode == B.scaling_mode || @@ -91,128 +90,135 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla "Inputs A and B to GEMM need to have compatible scaling modes!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); - GemmParam ret(transA, transB); + GemmParam ret; + + // Device compute capability + const int arch = cuda::sm_arch(); + // Transpose mode with column-major ordering bool transa_bool = transA == CUBLAS_OP_T; bool transb_bool = transB == CUBLAS_OP_T; - int arch = cuda::sm_arch(cuda::current_device()); - int a_storage_outer_dim; - int b_storage_outer_dim; - if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { - // For this scaling mode, a quantized tensor of the data is stored - // in a row major layout for rowwise data and a quantized tensor of - // the transpose of the data is also stored in row major layout. - // - // cublas will be called with "TN", but Transformer engine uses - // the "TN" parameters to choose between rowwise and columnwise - // row major tensors. - - a_storage_outer_dim = m; - b_storage_outer_dim = n; - ret.lda = k; - ret.ldb = k; - - ret.transA = CUBLAS_OP_T; - ret.transB = CUBLAS_OP_N; - - NVTE_CHECK(ret.lda == ret.ldb, "Minor dimension must be equal for NVTE_BLOCK_SCALING Gemm."); - - } else { - // In these scaling modes, the physical layout of - // the tensor will always line up with transA and - // transB, which are passed along to cuBLAS. - // NOTE: There is some logic below that may edit this - // decision for A and B depending on dtype and arch. - a_storage_outer_dim = transa_bool ? m : k; - b_storage_outer_dim = transb_bool ? k : n; - ret.lda = transa_bool ? k : m; - ret.ldb = transb_bool ? n : k; - - if (transa_bool && transb_bool) { // TT - NVTE_ERROR("TT layout not allowed."); - } - } - + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { + // Unscaled or FP8 tensor scaling ret.A = A.data.dptr; + ret.transA = transA; + ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; - if (transA == CUBLAS_OP_T) { - ret.Atype = A.data.dtype; - } else { - ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; - if (is_fp8_dtype(ret.Atype)) { - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); - ret.A = A.columnwise_data.dptr; - ret.transA = CUBLAS_OP_T; - ret.A_scale_inv = A.columnwise_scale_inv.dptr; - a_storage_outer_dim = m; - ret.lda = k; - } + ret.lda = transa_bool ? k : m; + if (arch < 100 && !transa_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } - ret.B = B.data.dptr; - ret.B_scale_inv = B.scale_inv.dptr; - if (transB == CUBLAS_OP_T) { - ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; - if (is_fp8_dtype(ret.Btype)) { - if (arch < 100) { - // Hopper and Ada - we need to use columnwise_data and change transA - NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); - ret.B = B.columnwise_data.dptr; - ret.transB = CUBLAS_OP_N; - ret.B_scale_inv = B.columnwise_scale_inv.dptr; - b_storage_outer_dim = n; - ret.ldb = k; - } - } + } else if (is_mxfp_scaling(A.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { - ret.Btype = B.data.dtype; + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); } - } else { - // MXF8 scaling or NVTE_BLOCK_SCALING - // If not tensor scaling (which includes also high precision types), we need to - // use the proper version of data - // For MXF8, we leave the transA/B values as is, since Blackwell supports transposes - // but for NVTE_BLOCK_SCALING, we force transA/B to TN since the quantizers - // store data in that manner and the GEMM requires that layout. - if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { - if (transA == CUBLAS_OP_T) { - NVTE_CHECK(A.has_data(), "Input A is not suitable for rowwise usage!"); - } else { - NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); - } - if (transB == CUBLAS_OP_N) { - NVTE_CHECK(B.has_data(), "Input B is not suitable for rowwise usage!"); - } else { - NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); - } - // Requirements from - // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK((ret.lda % 16) == 0, - "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); - // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. - // Smallest supported CType is 2 bytes in this scaling mode. - NVTE_CHECK((a_storage_outer_dim % 8) == 0, - "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); - // Observed this requirement only present for B tensor is 1D quantized. - if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { - NVTE_CHECK( - (b_storage_outer_dim % 8) == 0, - "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); - } - NVTE_CHECK((ret.ldb % 16) == 0, - "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = transA; + ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = m; + } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transa_bool) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing columnwise-wise usage"); } ret.A = transa_bool ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; ret.Atype = transa_bool ? A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = transa_bool ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.lda % 16) == 0, + "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. + // Smallest supported CType is 2 bytes in this scaling mode. + NVTE_CHECK((m % 8) == 0, + "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } else { + NVTE_ERROR("A has unsupported scaling mode"); + } + + // Configure B matrix + if (is_tensor_scaling(B.scaling_mode)) { + // Unscaled or FP8 tensor scaling + ret.B = B.data.dptr; + ret.transB = transB; + ret.Btype = B.data.dtype; + ret.B_scale_inv = B.scale_inv.dptr; + ret.ldb = transb_bool ? n : k; + if (arch < 100 && transb_bool) { + // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } else { + NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); + } + } + } else if (is_mxfp_scaling(B.scaling_mode)) { + // MXFP8 + // Note: Row-wise and column-wise data are scaled along different + // dimensions (with matrix interpreted in row-major order). + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = transB; + ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) { + // FP8 block scaling + // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. + if (transb_bool) { + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } ret.B = transb_bool ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; ret.Btype = transb_bool ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = transb_bool ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + + // Requirements from + // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK((ret.ldb % 16) == 0, + "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) { + // Observed this requirement only present for B tensor is 1D quantized. + NVTE_CHECK((n % 8) == 0, + "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad."); + } + } else { + NVTE_ERROR("B has unsupported scaling mode"); } + return ret; } @@ -223,19 +229,33 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, int m, int k, int n, + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + // Tensor dims in row-major order + const int A0 = inputA->flat_first_dim(); + const int A1 = inputA->flat_last_dim(); + const int B0 = inputB->flat_first_dim(); + const int B1 = inputB->flat_last_dim(); + + // GEMM dims in column-major order + const int m = transa == CUBLAS_OP_T ? A0 : A1; + const int n = transb == CUBLAS_OP_T ? B1 : B0; + const int k = transa == CUBLAS_OP_T ? A1 : A0; + NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k, + "GEMM inputs have incompatible dimensions (A is ", + A0, "x", A1, ", B is ", B0, "x", B1, ")"); const int ldd = m; + // Return immediately if GEMM is trivial if (m <= 0 || n <= 0) { return; } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, k, n); + const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); void *C = outputD->data.dptr; void *D = outputD->data.dptr; @@ -577,16 +597,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const size_t A0 = inputA->flat_first_dim(); - const size_t A1 = inputA->flat_last_dim(); - const size_t B0 = inputB->flat_first_dim(); - const size_t B1 = inputB->flat_last_dim(); - - const int m = transa ? A0 : A1; - const int k = transa ? A1 : A0; - const int n = transb ? B1 : B0; - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); @@ -617,12 +628,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, k, n, + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); From 9f0707e5f39ba9f0b924e63c0c08f9e993c04ba1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 6 Apr 2025 01:12:09 +0000 Subject: [PATCH 24/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ed691aa532..6fe3539257 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -229,11 +229,10 @@ namespace transformer_engine { using cublasHandleManager = detail::HandleManager; void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, - cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, cudaStream_t stream) { + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, + int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); const int A1 = inputA->flat_last_dim(); @@ -245,8 +244,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int n = transb == CUBLAS_OP_T ? B1 : B0; const int k = transa == CUBLAS_OP_T ? A1 : A0; NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k, - "GEMM inputs have incompatible dimensions (A is ", - A0, "x", A1, ", B is ", B0, "x", B1, ")"); + "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, + ")"); const int ldd = m; // Return immediately if GEMM is trivial @@ -597,10 +596,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, 0, 0, false, nullptr, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -628,10 +626,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, - math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, + inputCounter, stream); } void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, From ba605f1811dca3e5d36986cca1bb881dc42d0d75 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 7 Apr 2025 14:14:48 +0800 Subject: [PATCH 25/29] [PyTorch][Common] Refactor RoPE (#1626) * refactor to add cp support for sbhd/bshd Signed-off-by: Xin Yao * support interleaved Signed-off-by: Xin Yao * format Signed-off-by: Xin Yao * add interleaved to RotaryPositionEmbedding in test Signed-off-by: Xin Yao * update Signed-off-by: Xin Yao * merge sbhd/bshd and thd functions Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao --- tests/pytorch/test_fused_rope.py | 136 +++---- .../common/fused_rope/fused_rope.cu | 379 ++++++++---------- .../include/transformer_engine/fused_rope.h | 108 ++--- transformer_engine/pytorch/csrc/extensions.h | 14 +- .../pytorch/csrc/extensions/apply_rope.cpp | 286 ++++++------- .../pytorch/csrc/extensions/pybind.cpp | 4 - .../pytorch/dot_product_attention/rope.py | 248 ++++++++---- 7 files changed, 566 insertions(+), 609 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index e236a29a9d..5d1adf4e02 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -11,52 +11,6 @@ ) -def _get_thd_freqs_on_this_cp_rank( - cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - if cp_size > 1: - cp_seg = x.size(0) // 2 - full_seqlen = cp_size * x.size(0) - return torch.cat( - [ - freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], - freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], - ] - ) - else: - return freqs[: x.size(0)] - - -def apply_rotary_pos_emb_thd( - t: torch.Tensor, - cu_seqlens: torch.Tensor, - freqs: torch.Tensor, - cp_size: int = 1, - cp_rank: int = 0, -) -> torch.Tensor: - """A baseline implementation of applying RoPE for `thd` format. - - Args: - t (Tensor): Input tensor T is of shape [t, h, d] - cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, - with shape [b + 1] and dtype torch.int32. - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] - - Returns: - Tensor: Shape [t, h, d]. The input tensor after applying RoPE. - """ - cu_seqlens = cu_seqlens // cp_size - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return torch.cat( - [ - apply_rotary_pos_emb( - x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs) - ) - for x in torch.split(t, seqlens) - ] - ).squeeze(1) - - # Gradient is a broadcasted scalar def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: return output.sum() * 2 @@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)]) @pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) def test_fused_rope( dtype: torch.dtype, seq_length: int, @@ -85,6 +41,8 @@ def test_fused_rope( transpose: Union[Tuple, None], tensor_format: str, loss_func: Callable, + cp_size: int, + interleaved: bool, ) -> None: device = torch.device("cuda:0") batch_size, head_num = 2, 64 @@ -99,35 +57,46 @@ def test_fused_rope( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) - emb = rotary_pos_emb(seq_length) - - # unfused - # The fused kernel computes in float32 internally, so we force the unfused func to use float32 - # for more accurate comparison - output_unfused = apply_rotary_pos_emb( - t.float(), emb, tensor_format=tensor_format, fused=False - ).to(dtype) - loss_unfused = loss_func(output_unfused) - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - t.grad = None - - # fused - output_fused = apply_rotary_pos_emb( - t, - emb, - tensor_format=tensor_format, - fused=True, - ) - loss_fused = loss_func(output_fused) - loss_fused.backward() - grad_fused = t.grad.detach().clone() - t.grad = None + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb = rotary_pos_emb(seq_length * cp_size) + assert emb.is_contiguous() - torch.testing.assert_close(output_fused, output_unfused) - torch.testing.assert_close(grad_fused, grad_unfused) - assert output_fused.is_contiguous() + for cp_rank in range(cp_size): + # unfused + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + output_unfused = apply_rotary_pos_emb( + t.float(), + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=False, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + loss_unfused = loss_func(output_unfused) + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + t.grad = None + + # fused + output_fused = apply_rotary_pos_emb( + t, + emb, + tensor_format=tensor_format, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ) + loss_fused = loss_func(output_fused) + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) + assert output_fused.is_contiguous() @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) @@ -135,7 +104,8 @@ def test_fused_rope( @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) -@pytest.mark.parametrize("cp_size", [1, 2, 3]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) def test_fused_rope_thd( dtype: torch.dtype, hidden_size: int, @@ -143,6 +113,7 @@ def test_fused_rope_thd( transpose: Union[Tuple, None], loss_func: Callable, cp_size: int, + interleaved: bool, ) -> None: device = torch.device("cuda:0") batch_size, head_num = 2, 64 @@ -170,15 +141,23 @@ def test_fused_rope_thd( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True - rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) emb = rotary_pos_emb(cu_seqlens_padded[-1]) + assert emb.is_contiguous() for cp_rank in range(cp_size): # unfused # The fused kernel computes in float32 internally, so we force the unfused func to use float32 # for more accurate comparison - output_unfused = apply_rotary_pos_emb_thd( - t.float(), cu_seqlens_padded, emb, cp_size, cp_rank + output_unfused = apply_rotary_pos_emb( + t.float(), + emb, + tensor_format="thd", + interleaved=interleaved, + fused=False, + cu_seqlens=cu_seqlens_padded, + cp_size=cp_size, + cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) loss_unfused.backward() @@ -189,6 +168,7 @@ def test_fused_rope_thd( output_fused = apply_rotary_pos_emb( t, emb, + interleaved=interleaved, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens_padded, diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 7f35ddd70b..1ab6d4ed2c 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -16,10 +16,11 @@ namespace transformer_engine { template __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int s_id, const int offset_block, - const int offset_block_dst, const int h, const int d, - const int d2, const int stride_h, const int stride_d, - const int o_stride_h, const int o_stride_d) { + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, const int stride_h, + const int stride_d, const int o_stride_h, + const int o_stride_d) { #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos, v_sin; @@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate = (d_id + d2 / 2 < d2) - ? -static_cast(src[offset_src + (d2 / 2) * stride_d]) - : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? -static_cast(src[offset_src + (d2 / 2) * stride_d]) + : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + } else { + v_src_rotate = (d_id % 2 == 0) + // d_id + 1 + ? -static_cast(src[offset_src + stride_d]) + // d_id - 1 + : static_cast(src[offset_src - stride_d]); + } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } @@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs template __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, - const int s_id, const int offset_block, - const int offset_block_dst, const int h, const int d, - const int d2, const int stride_h, const int stride_d, + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos = cosf(freqs[s_id * d2 + d_id]); - float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) - : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); + float v_sin; + if (!interleaved) { + v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) + : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); + } else { + v_sin = + (d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]); + } #pragma unroll for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate = (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d] - : src[offset_src + (d2 / 2 - d2) * stride_d]; + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? static_cast(src[offset_src + (d2 / 2) * stride_d]) + : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + } else { + v_src_rotate = (d_id % 2 == 0) + // d_id + 1 + ? static_cast(src[offset_src + stride_d]) + // d_id - 1 + : static_cast(src[offset_src - stride_d]); + } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } @@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq } template -__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, +__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens, + const float *freqs, scalar_t *dst, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int h, const int d, const int d2, - const int stride_s, const int stride_b, + const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, + const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = s_id * stride_s + b_id * stride_b; - int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, - stride_h, stride_d, o_stride_h, o_stride_d); -} - -template -__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d) { - int s_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = s_id * stride_s + b_id * stride_b; - int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, - stride_h, stride_d, o_stride_h, o_stride_d); -} - -template -__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int cp_size, - const int cp_rank, const int h, const int d, - const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d) { - int s_id = blockIdx.x, b_id = blockIdx.y; - int start = cu_seqlens[b_id] / cp_size; - int end = cu_seqlens[b_id + 1] / cp_size; - int t_id = s_id + start; - if (t_id >= end) return; - int offset_block = t_id * stride_t; - int offset_block_dst = t_id * o_stride_t; + int offset_block, offset_block_dst; + int cur_seqlens; + if (cu_seqlens != nullptr) { // THD + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; + offset_block = t_id * stride_s_or_t; + offset_block_dst = t_id * o_stride_s_or_t; + cur_seqlens = end - start; + } else { // SBHD/BSHD + offset_block = s_id * stride_s_or_t + b_id * stride_b; + offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b; + cur_seqlens = s; + } int s_id_for_freqs; if (cp_size > 1) { - int cur_seqlens = end - start; assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; @@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu } else { s_id_for_freqs = s_id; } - fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, - d2, stride_h, stride_d, o_stride_h, o_stride_d); + + fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, + offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template -__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int cp_size, - const int cp_rank, const int h, const int d, - const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_backward_kernel( + const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst, + const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, + const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int start = cu_seqlens[b_id] / cp_size; - int end = cu_seqlens[b_id + 1] / cp_size; - int t_id = s_id + start; - if (t_id >= end) return; - int offset_block = t_id * stride_t; - int offset_block_dst = t_id * o_stride_t; + int offset_block, offset_block_dst; + int cur_seqlens; + if (cu_seqlens != nullptr) { // THD + int start = cu_seqlens[b_id] / cp_size; + int end = cu_seqlens[b_id + 1] / cp_size; + int t_id = s_id + start; + if (t_id >= end) return; + offset_block = t_id * stride_s_or_t; + offset_block_dst = t_id * o_stride_s_or_t; + cur_seqlens = end - start; + } else { // SBHD/BSHD + offset_block = s_id * stride_s_or_t + b_id * stride_b; + offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b; + cur_seqlens = s; + } int s_id_for_freqs; if (cp_size > 1) { - int cur_seqlens = end - start; assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; @@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c } else { s_id_for_freqs = s_id; } - fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, - d2, stride_h, stride_d, o_stride_h, o_stride_d); + + fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, + offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template -void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, +void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs, + scalar_t *output, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + int o_stride_s_or_t, o_stride_b; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); + o_stride_s_or_t = h * d; + o_stride_b = 0; + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + o_stride_s_or_t = b * h * d; + o_stride_b = h * d; + } else { + o_stride_s_or_t = h * d; + o_stride_b = s * h * d; + } + const int o_stride_h = d; + const int o_stride_d = 1; fused_rope_forward_kernel<<>>( - input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d); + input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template -void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, - scalar_t *input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, +void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, + const float *freqs, scalar_t *input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + int o_stride_s_or_t, o_stride_b; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); + o_stride_s_or_t = h * d; + o_stride_b = 0; + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + o_stride_s_or_t = b * h * d; + o_stride_b = h * d; + } else { + o_stride_s_or_t = h * d; + o_stride_b = s * h * d; + } + const int o_stride_h = d; + const int o_stride_d = 1; fused_rope_backward_kernel<<>>( - output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, - o_stride_s, o_stride_b, o_stride_h, o_stride_d); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens, - const float *freqs, scalar_t *output, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { - int warps_per_block = h < 16 ? 4 : 8; - dim3 blocks(max_s, b); - dim3 threads(THREADS_PER_WARP, warps_per_block); - - fused_rope_thd_forward_kernel<<>>( - input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d); + output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, + stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, + o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } -template -void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { - int warps_per_block = h < 16 ? 4 : 8; - dim3 blocks(max_s, b); - dim3 threads(THREADS_PER_WARP, warps_per_block); - - fused_rope_thd_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h, - stride_d, o_stride_t, o_stride_h, o_stride_d); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s, - const int b, const int h, const int d, const int d2, const int stride_s, - const int stride_b, const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, + Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, const int h, + const int d, const int d2, const int stride_s_or_t, const int stride_b, + const int stride_h, const int stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), - reinterpret_cast(output->data.dptr), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream);); + reinterpret_cast(output->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, stream);); } -void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { +void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, + Tensor *input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), - reinterpret_cast(input_grads->data.dptr), s, b, h, d, - d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d, stream);); -} - -void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *output, const int cp_size, const int cp_rank, const int max_s, - const int b, const int h, const int d, const int d2, const int stride_t, - const int stride_h, const int stride_d, const int o_stride_t, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, scalar_t, - fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(output->data.dptr), cp_size, - cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, - o_stride_t, o_stride_h, o_stride_d, stream);); -} - -void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, - const Tensor &freqs, Tensor *input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - output_grads.data.dtype, scalar_t, - fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr), - reinterpret_cast(cu_seqlens.data.dptr), - reinterpret_cast(freqs.data.dptr), - reinterpret_cast(input_grads->data.dptr), - cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, - stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); + reinterpret_cast(input_grads->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, stream);); } } // end namespace transformer_engine -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_forward); using namespace transformer_engine; fused_rope_forward(*reinterpret_cast(input), + *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), reinterpret_cast(output), - s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream); + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, + stride_b, stride_h, stride_d, stream); } -void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*reinterpret_cast(output_grads), + *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), - reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b, - stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); -} - -void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_rope_thd_forward); - using namespace transformer_engine; - fused_rope_thd_forward(*reinterpret_cast(input), - *reinterpret_cast(cu_seqlens), - *reinterpret_cast(freqs), - reinterpret_cast(output), cp_size, cp_rank, max_s, b, h, d, d2, - stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); -} - -void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_rope_thd_backward); - using namespace transformer_engine; - fused_rope_thd_backward( - *reinterpret_cast(output_grads), - *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), - reinterpret_cast(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, - stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); + reinterpret_cast(input_grads), qkv_format, interleaved, cp_size, + cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index 41a0e3bc76..5a5bcc74ad 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -7,6 +7,7 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_ #define TRANSFORMER_ENGINE_FUSED_ROPE_H_ +#include "fused_attn.h" #include "transformer_engine.h" #ifdef __cplusplus @@ -16,112 +17,63 @@ extern "C" { /*! \brief Apply rotary positional embedding to the input tensor. * * \param[in] input Input tensor for fused rope. + * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. + * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. * \param[out] output Output tensor. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. * \param[in] s Length of the s dimension of input. * \param[in] b Length of the b dimension of input. * \param[in] h Length of the h dimension of input. * \param[in] d Length of the d dimension of input. * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_s Stride of the s dimension of input. - * \param[in] stride_b Stride of the b dimension of input. + * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of input. + * \param[in] stride_b Stride of the b dimension of input. (0 for thd). * \param[in] stride_h Stride of the h dimension of input. * \param[in] stride_d Stride of the d dimension of input. - * \param[in] o_stride_s Stride of the s dimension of output. - * \param[in] o_stride_b Stride of the b dimension of output. - * \param[in] o_stride_h Stride of the h dimension of output. - * \param[in] o_stride_d Stride of the d dimension of output. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream); +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor output, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream); /*! \brief Compute the backward of the fused rope. * * \param[in] output_grads Incoming gradient tensor for backward. + * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. + * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. * \param[out] input_grads Input gradient tensor to calculate. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. * \param[in] s Length of the s dimension of output_grads. * \param[in] b Length of the b dimension of output_grads. * \param[in] h Length of the h dimension of output_grads. * \param[in] d Length of the d dimension of output_grads. * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_s Stride of the s dimension of output_grads. - * \param[in] stride_b Stride of the b dimension of output_grads. + * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of output_grads. + * \param[in] stride_b Stride of the b dimension of output_grads. (0 for thd). * \param[in] stride_h Stride of the h dimension of output_grads. * \param[in] stride_d Stride of the d dimension of output_grads. - * \param[in] o_stride_s Stride of the s dimension of input_grads. - * \param[in] o_stride_b Stride of the b dimension of input_grads. - * \param[in] o_stride_h Stride of the h dimension of input_grads. - * \param[in] o_stride_d Stride of the d dimension of input_grads. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, + const NVTETensor freqs, NVTETensor input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream); -/*! \brief Apply rotary positional embedding to the input tensor in thd format. - * - * \param[in] input Input tensor for fused rope. - * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. - * \param[in] freqs The freqs tensor. - * \param[out] output Output tensor. - * \param[in] cp_size Context parallel world size. - * \param[in] cp_rank Context parallel rank. - * \param[in] max_s Max sequence length. - * \param[in] b Batch size. - * \param[in] h Length of the h dimension of input. - * \param[in] d Length of the d dimension of input. - * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_t Stride of the t dimension of input. - * \param[in] stride_h Stride of the h dimension of input. - * \param[in] stride_d Stride of the d dimension of input. - * \param[in] o_stride_t Stride of the t dimension of output. - * \param[in] o_stride_h Stride of the h dimension of output. - * \param[in] o_stride_d Stride of the d dimension of output. - * \param[in] stream CUDA stream used for the operation. - */ -void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream); - -/*! \brief Compute the backward of the fused rope in thd format. - * - * \param[in] output_grads Incoming gradient tensor for backward. - * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. - * \param[in] freqs The freqs tensor. - * \param[out] input_grads Input gradient to calculate. - * \param[in] cp_size Context parallel world size. - * \param[in] cp_rank Context parallel rank. - * \param[in] max_s Max sequence length. - * \param[in] b Batch size. - * \param[in] h Length of the h dimension of output_grads. - * \param[in] d Length of the d dimension of output_grads. - * \param[in] d2 Length of the d dimension of freqs. - * \param[in] stride_t Stride of the t dimension of output_grads. - * \param[in] stride_h Stride of the h dimension of output_grads. - * \param[in] stride_d Stride of the d dimension of output_grads. - * \param[in] o_stride_t Stride of the t dimension of input_grads. - * \param[in] o_stride_h Stride of the h dimension of input_grads. - * \param[in] o_stride_d Stride of the d dimension of input_grads. - * \param[in] stream CUDA stream used for the operation. - */ -void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int cp_size, - const int cp_rank, const int max_s, const int b, const int h, - const int d, const int d2, const int stride_t, const int stride_h, - const int stride_d, const int o_stride_t, const int o_stride_h, - const int o_stride_d, cudaStream_t stream); - #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d7abfcb45c..a66fbf950d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -265,16 +265,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio **************************************************************************************************/ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, - const bool transpose_output_memory); + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, - const bool transpose_output_memory); - -at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank); - -at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank); + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank); /*************************************************************************************************** * Miscellaneous diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index c323e7b6c1..424a988301 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -7,217 +7,181 @@ #include "extensions.h" at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, - const bool transpose_output_memory) { + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank) { using namespace transformer_engine::pytorch; - TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(input.size(0) <= freqs.size(0), - "expected freqs tensor has a longer sequence length than input"); TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(input.size(3) >= freqs.size(3), - "expected the last dim of the input tensor equals or is " - "greater than the freqs tensor"); TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, "Dtype of the freqs tensor must be float"); - // input sizes: (s, b, h, d) + // output + auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device()); + auto output = at::empty(input.sizes(), act_options); + + auto input_cu = makeTransformerEngineTensor(input); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto output_cu = makeTransformerEngineTensor(output); + + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); + TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor"); + TORCH_CHECK(input.size(2) >= freqs.size(3), + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); + + // input sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + // const int t = input.size(0); + const int h = input.size(1); + const int d = input.size(2); + // input strides + const int stride_t = input.stride(0); + const int stride_h = input.stride(1); + const int stride_d = input.stride(2); + // batch size + const int b = cu_seqlens.value().size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + + nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b, + h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, + at::cuda::getCurrentCUDAStream()); + + return output; + } + + TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); + // input sizes: (s, b, h, d) or (b, s, h, d) // s: sequence length // b: batch size // h: head num // d: dim of each head - const int s = input.size(0); - const int b = input.size(1); + const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1); + const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0); const int h = input.size(2); const int d = input.size(3); // input strides - const int stride_s = input.stride(0); - const int stride_b = input.stride(1); + const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1); + const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0); const int stride_h = input.stride(2); const int stride_d = input.stride(3); // freqs' shape is always (s, 1, 1, d2), so the strides are same under // different memory formats const int d2 = freqs.size(3); - // output - auto act_options = input.options().requires_grad(false); - at::Tensor output; - if (transpose_output_memory) { - output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); - } else { - output = torch::empty({s, b, h, d}, act_options); - } - // output strides - const int o_stride_s = output.stride(0); - const int o_stride_b = output.stride(1); - const int o_stride_h = output.stride(2); - const int o_stride_d = output.stride(3); - - auto input_cu = makeTransformerEngineTensor(input); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto output_cu = makeTransformerEngineTensor(output); + TORCH_CHECK(s * cp_size <= freqs.size(0), + "expected freqs tensor has a longer sequence length than input"); + TORCH_CHECK(d >= d2, + "expected the last dim of the input tensor equals or is " + "greater than the freqs tensor"); - nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor + nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, + stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return output; } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, - const bool transpose_output_memory) { + const NVTE_QKV_Format qkv_format, const bool interleaved, + const c10::optional cu_seqlens, const int cp_size, + const int cp_rank) { using namespace transformer_engine::pytorch; - TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(output_grads.size(0) <= freqs.size(0), - "expected freqs tensor has a longer sequence length than output_grads"); TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(output_grads.size(3) >= freqs.size(3), - "expected the last dim of the output_grads tensor equals or is " - "greater than the freqs tensor"); TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, "Dtype of the freqs tensor must be float"); + auto act_options = + at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device()); + auto input_grads = at::empty(output_grads.sizes(), act_options); + + auto output_grads_cu = makeTransformerEngineTensor(output_grads); + auto freqs_cu = makeTransformerEngineTensor(freqs); + auto input_grads_cu = makeTransformerEngineTensor(input_grads); + + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); + TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor"); + TORCH_CHECK(output_grads.size(2) >= freqs.size(3), + "expected the last dim of the output_grads tensor equals or is " + "greater than the freqs tensor"); + + // output_grads sizes: (t, h, d) + // t: cumulative sum of sequence lengths + // h: head num + // d: dim of each head + // const int t = output_grads.size(0); + const int h = output_grads.size(1); + const int d = output_grads.size(2); + // output_grads strides + const int stride_t = output_grads.stride(0); + const int stride_h = output_grads.stride(1); + const int stride_d = output_grads.stride(2); + // batch size + const int b = cu_seqlens.value().size(0) - 1; + // freqs' shape is (max_s, 1, 1, d2) + const int max_s = freqs.size(0); + const int d2 = freqs.size(3); + + auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); + + nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, + max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, + at::cuda::getCurrentCUDAStream()); + + return input_grads; + } + + TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); // output_grads sizes: (s, b, h, d) // s: sequence length // b: batch size // h: head num // d: dim of each head - const int s = output_grads.size(0); - const int b = output_grads.size(1); + const int s = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1); + const int b = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0); const int h = output_grads.size(2); const int d = output_grads.size(3); // output_grads strides - const int stride_s = output_grads.stride(0); - const int stride_b = output_grads.stride(1); + const int stride_s = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1); + const int stride_b = + qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0); const int stride_h = output_grads.stride(2); const int stride_d = output_grads.stride(3); // freqs' shape is always (s, 1, 1, d2), so the strides are same under // different memory formats const int d2 = freqs.size(3); - auto act_options = output_grads.options().requires_grad(false); - at::Tensor input_grads; - if (transpose_output_memory) { - input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); - } else { - input_grads = torch::empty({s, b, h, d}, act_options); - } - const int o_stride_s = input_grads.stride(0); - const int o_stride_b = input_grads.stride(1); - const int o_stride_h = input_grads.stride(2); - const int o_stride_d = input_grads.stride(3); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto input_grads_cu = makeTransformerEngineTensor(input_grads); - - nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, - d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); - - return input_grads; -} - -at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine::pytorch; - TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); - TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(input.size(2) >= freqs.size(3), - "expected the last dim of the input tensor equals or is " - "greater than the freqs tensor"); - TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - - // input sizes: (t, h, d) - // t: cumulative sum of sequence lengths - // h: head num - // d: dim of each head - const int t = input.size(0); - const int h = input.size(1); - const int d = input.size(2); - // input strides - const int stride_t = input.stride(0); - const int stride_h = input.stride(1); - const int stride_d = input.stride(2); - // batch size - const int b = cu_seqlens.size(0) - 1; - // freqs' shape is (max_s, 1, 1, d2) - const int max_s = freqs.size(0); - const int d2 = freqs.size(3); - - // output - auto act_options = input.options().requires_grad(false); - auto output = torch::empty({t, h, d}, act_options); - // output strides - const int o_stride_t = output.stride(0); - const int o_stride_h = output.stride(1); - const int o_stride_d = output.stride(2); - - auto input_cu = makeTransformerEngineTensor(input); - auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto output_cu = makeTransformerEngineTensor(output); - - nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, - stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, - at::cuda::getCurrentCUDAStream()); - - return output; -} - -at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine::pytorch; - TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); - TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); - TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, - "expected the second and third dims of the freqs tensor equal 1"); - TORCH_CHECK(output_grads.size(2) >= freqs.size(3), + TORCH_CHECK(s * cp_size <= freqs.size(0), + "expected freqs tensor has a longer sequence length than output_grads"); + TORCH_CHECK(d >= d2, "expected the last dim of the output_grads tensor equals or is " "greater than the freqs tensor"); - TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, - "Dtype of the freqs tensor must be float"); - - // output_grads sizes: (t, h, d) - // t: cumulative sum of sequence lengths - // h: head num - // d: dim of each head - const int t = output_grads.size(0); - const int h = output_grads.size(1); - const int d = output_grads.size(2); - // output_grads strides - const int stride_t = output_grads.stride(0); - const int stride_h = output_grads.stride(1); - const int stride_d = output_grads.stride(2); - // batch size - const int b = cu_seqlens.size(0) - 1; - // freqs' shape is (max_s, 1, 1, d2) - const int max_s = freqs.size(0); - const int d2 = freqs.size(3); - - auto act_options = output_grads.options().requires_grad(false); - auto input_grads = torch::empty({t, h, d}, act_options); - const int o_stride_t = input_grads.stride(0); - const int o_stride_h = input_grads.stride(1); - const int o_stride_d = input_grads.stride(2); - - auto output_grads_cu = makeTransformerEngineTensor(output_grads); - auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); - auto freqs_cu = makeTransformerEngineTensor(freqs); - auto input_grads_cu = makeTransformerEngineTensor(input_grads); - nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, - stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, - at::cuda::getCurrentCUDAStream()); + auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor + nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, + h, d, d2, stride_s, stride_b, stride_h, stride_d, + at::cuda::getCurrentCUDAStream()); return input_grads; } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 60a97dad3c..617ba42d4a 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -229,10 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD", py::call_guard()); - m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format", - py::call_guard()); - m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format", - py::call_guard()); // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version", diff --git a/transformer_engine/pytorch/dot_product_attention/rope.py b/transformer_engine/pytorch/dot_product_attention/rope.py index 83698c7bc6..6793f1b760 100644 --- a/transformer_engine/pytorch/dot_product_attention/rope.py +++ b/transformer_engine/pytorch/dot_product_attention/rope.py @@ -7,7 +7,12 @@ """ from typing import Optional, Tuple, Union import torch + import transformer_engine_torch as tex +from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat + + +__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"] class RotaryPositionEmbedding(torch.nn.Module): @@ -22,19 +27,24 @@ def __init__( seq_len_interpolation_factor: Optional[int] = None, pretrained_max_position_embeddings: Optional[int] = None, rotary_base: float = 10000.0, + interleaved: bool = False, ): """ Parameters ---------- dim: int - rotary embedding dimension - rotary_percent: float + Rotary embedding dimension. + rotary_percent: float, default = 1.0 Percent of rotary dimension to use for rotary position embeddings. - seq_len_interpolation_factor: int - if not None, discrete positions will be interpolated by this factor via the trick in + seq_len_interpolation_factor: int, default = None + If not None, discrete positions will be interpolated by this factor via the trick in https://arxiv.org/abs/2306.15595 - pretrained_max_position_embeddings: int - pre-trained max_position_embeddings before position interpolation + pretrained_max_position_embeddings: int, default = None + Pre-trained max_position_embeddings before position interpolation. + rotary_base: float, default = 10000.0 + Base of the rotary position embedding. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. """ super().__init__() if rotary_percent < 1.0: @@ -50,17 +60,18 @@ def __init__( ) self.register_buffer("inv_freq", inv_freq) self.pretrained_max_position_embeddings = pretrained_max_position_embeddings + self.interleaved = interleaved def forward(self, max_seq_len: int, offset: int = 0): """ - Create rotary position embedding frequencies + Create rotary position embedding frequencies. Parameters ---------- max_seq_len: int - sequence length of a sample + Sequence length of a sample. offset: int, default = 0 - fixed offset for freqencies + Fixed offset for frequencies. """ seq = ( torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) @@ -84,7 +95,12 @@ def forward(self, max_seq_len: int, offset: int = 0): freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) # first part even vector components, second part odd vector components, # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) + if not self.interleaved: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view( + freqs.shape[0], -1 + ) # emb [seq_length, .., dim] return emb.reshape(emb.size(0), 1, 1, emb.size(1)) @@ -104,61 +120,146 @@ def forward( t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd", + interleaved: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, cp_rank: int = 0, ) -> torch.Tensor: - # pylint: disable=missing-function-docstring + """Fused RoPE forward.""" if freqs.dtype != torch.float32: freqs = freqs.float() - if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, False) - elif tensor_format == "bshd": - output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) - elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) - else: - raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + assert tensor_format in ( + "sbhd", + "bshd", + "thd", + ), f"Unsupported tensor_format: {tensor_format}." + output = tex.fused_rope_forward( + t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank + ) ctx.save_for_backward(freqs, cu_seqlens) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank + ctx.interleaved = interleaved return output @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring + """Fused RoPE backward.""" freqs, cu_seqlens = ctx.saved_tensors - if ctx.tensor_format == "sbhd": - grad_input = tex.fused_rope_backward(grad_output, freqs, False) - elif ctx.tensor_format == "bshd": - grad_input = tex.fused_rope_backward( - grad_output.transpose(0, 1), freqs, True - ).transpose(0, 1) - elif ctx.tensor_format == "thd": - grad_input = tex.fused_rope_thd_backward( - grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank - ) - else: - raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") + grad_input = tex.fused_rope_backward( + grad_output, + freqs, + QKVFormat[ctx.tensor_format], + ctx.interleaved, + cu_seqlens, + ctx.cp_size, + ctx.cp_rank, + ) + + return grad_input, None, None, None, None, None, None - return grad_input, None, None, None, None, None +def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: + """Change sign so the last dimension becomes [-odd, +even] -def _rotate_half(x: torch.Tensor) -> torch.Tensor: + Args: + x: torch.Tensor. Input tensor. + interleaved: bool. Whether to use interleaved rotary position embedding. + + Returns: + Tensor: Tensor rotated half. """ - change sign so the last dimension becomes [-odd, +even] + if not interleaved: + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + # interleaved + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x_new = torch.stack((-x2, x1), dim=-1) + return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1) + + +def _apply_rotary_pos_emb_base( + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + interleaved: bool = False, +) -> torch.Tensor: + """ + Base implementation of applying rotary positional embedding tensor to the input tensor. + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional + embedding will be applied. + freqs: torch.Tensor + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape + `[seq, bs, ...]`. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + """ + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert ( + cur_seq_len <= max_seq_len + ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + freqs = freqs[:cur_seq_len] + if tensor_format == "bshd": + freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] + # cos/sin first then dtype conversion for better precision + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * cos_) + (_rotate_half(t, interleaved) * sin_) + return torch.cat((t, t_pass), dim=-1) + + +def _get_freqs_on_this_cp_rank( + freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int +) -> torch.Tensor: + """Get the position embedding on the current context parallel rank. + + Args: + freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`. + seqlen: int. Length of the current sequence. + cp_size: int. Context parallel world size. + cp_rank: int. Context parallel rank. """ - x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) + if cp_size > 1: + cp_seg = seqlen // 2 + full_seqlen = cp_size * seqlen + return torch.cat( + [ + freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], + freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], + ] + ) + + # cp_size == 1 + return freqs[:seqlen] def apply_rotary_pos_emb( t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd", + interleaved: bool = False, fused: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, @@ -175,11 +276,13 @@ def apply_rotary_pos_emb( freqs: torch.Tensor Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', with `s2 >= s` and `d2 <= d`. - fused: bool, default = False - Whether to use a fused applying RoPE implementation. tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + fused: bool, default = False + Whether to use a fused applying RoPE implementation. cu_seqlens: torch.Tensor, default = None. Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. @@ -189,37 +292,40 @@ def apply_rotary_pos_emb( cp_rank: int, default = 0. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. """ - if fused: - assert ( - tensor_format != "thd" or cu_seqlens is not None - ), "cu_seqlens must not be None when tensor_format is 'thd'." - return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank) - - assert tensor_format in ("sbhd", "bshd"), ( - "Only formats `sbhd` or `bshd` are supported for input tensor `t` " - f"when fused is False, got {tensor_format}." - ) - - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # Only apply the rotary embeddings up to the sequence length of the running - # input. assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" - freqs = freqs[:cur_seq_len] - if tensor_format == "bshd": - freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] - # cos/sin first then dtype conversion for better precision - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) + tensor_format != "thd" or cu_seqlens is not None + ), "cu_seqlens must not be None when tensor_format is 'thd'." - rot_dim = freqs.shape[-1] - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + if fused: + return FusedRoPEFunc.apply( + t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank + ) - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) + # Unfused THD format + if tensor_format == "thd": + cu_seqlens = cu_seqlens // cp_size + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return torch.cat( + [ + _apply_rotary_pos_emb_base( + x.unsqueeze(1), + _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank), + interleaved=interleaved, + ) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + + # Unfused SBHD/BSHD format + if tensor_format == "sbhd": + seqlen = t.size(0) + elif tensor_format == "bshd": + seqlen = t.size(1) + else: + raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + return _apply_rotary_pos_emb_base( + t, + _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank), + tensor_format, + interleaved=interleaved, + ) From a3ba4dffaccccb7c191e19ffcfa354b92ba2a466 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Mon, 7 Apr 2025 14:17:46 +0800 Subject: [PATCH 26/29] Fix cpp warnings (#1639) * fix cpp warning Signed-off-by: Xin Yao * more fix Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao --- .../common/util/cast_gated_kernels.cuh | 6 ------ .../common/util/cast_kernels.cuh | 3 --- .../common/util/dequantize_kernels.cuh | 4 +--- .../csrc/extensions/comm_gemm_overlap.cpp | 20 +++++++++---------- .../pytorch/csrc/extensions/gemm.cpp | 3 +-- .../pytorch/csrc/extensions/padding.cpp | 2 +- .../pytorch/csrc/extensions/quantizer.cpp | 6 ++++-- .../pytorch/csrc/extensions/transpose.cpp | 2 +- 8 files changed, 18 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index e2240ba658..b29bc53c14 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -96,8 +96,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t out_act_mem = buff_size_aligned_out; - constexpr size_t out_gate_mem = buff_size_aligned_out; - constexpr size_t out_mem = out_act_mem + out_gate_mem; // const size_t in_transaction_size = grad_mem + in_mem; constexpr size_t in_transaction_size = buff_elems * sizeof(IType); @@ -108,7 +106,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - // uint64_t *mbar = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); @@ -289,7 +286,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 @@ -826,8 +822,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - e8m0_t *const scales_rowwise_ptr = USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; e8m0_t *const scales_colwise_ptr = diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 412a6f6ef0..c6a8b0f23c 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -142,7 +142,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; - constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); const bool is_master_thread = (threadIdx.x == 0); @@ -513,7 +512,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; - constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); const bool is_master_thread = (threadIdx.x == 0); @@ -927,7 +925,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, bool use_colwise_scaling = output->has_columnwise_data(); checkCuDriverContext(stream); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - const auto &input_shape = input.data.shape; NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); if (use_rowwise_scaling) { diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index c885c69333..967a0df3aa 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -56,7 +56,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scales_stride) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 @@ -65,8 +64,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 6d05869c36..aefb1d371d 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -157,15 +157,15 @@ void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); if (local_chunk) { - if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + if (input_tensor.numel() * _tp_size > _ubuf.numel()) NVTE_ERROR("input is larger than the local communication buffer!"); - if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + if (input_tensor.element_size() != _ubuf.element_size()) NVTE_ERROR("input data type does not match communication buffer!"); ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size(); } else { - if (input_tensor.numel() > (int64_t)_ubuf.numel()) + if (input_tensor.numel() > _ubuf.numel()) NVTE_ERROR("input is larger than the global communication buffer!"); - if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + if (input_tensor.element_size() != _ubuf.element_size()) NVTE_ERROR("input data type does not match communication buffer!"); } @@ -189,7 +189,7 @@ py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk, std::vector torch_shape; if (shape.has_value()) { torch_shape = shape.value(); - auto requested = product(torch_shape); + size_t requested = product(torch_shape); auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel(); NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, ") does not match allocated buffer size (", expected, ")!"); @@ -253,18 +253,18 @@ void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bo at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); if (local_chunk) { // Copy input to the target ubuf chunk by rank offset - if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + if (input_tensor.numel() * _tp_size > _ubuf.numel()) NVTE_ERROR("input is larger than the local communication buffer!"); - if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + if (input_tensor.element_size() != _ubuf.element_size()) NVTE_ERROR("input data type does not match communication buffer!"); NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr, input_tensor.numel() * input_tensor.element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } else { - if (input_tensor.numel() > (int64_t)_ubuf.numel()) + if (input_tensor.numel() > _ubuf.numel()) NVTE_ERROR("input is larger than the global communication buffer!"); - if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + if (input_tensor.element_size() != _ubuf.element_size()) NVTE_ERROR("input data type does not match communication buffer!"); NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr, input_tensor.numel() * input_tensor.element_size(), @@ -280,7 +280,7 @@ py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk, std::vector torch_shape; if (shape.has_value()) { torch_shape = shape.value(); - auto requested = product(torch_shape); + size_t requested = product(torch_shape); auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel(); NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, ") does not match allocated buffer size (", expected, ")!"); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index c5b96072df..ff61cd940c 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -318,12 +318,11 @@ std::optional> te_general_grouped_gemm( std::vector single_output_begins; std::vector single_output_ends; - int slicing_dim; if (single_output && D == std::nullopt) { NVTE_ERROR("not implemented, D should be allocated for single output case."); } - void* output_data_ptr; + void* output_data_ptr = nullptr; if (single_output) { output_data_ptr = (*D)[0].data_ptr(); } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index b9972af7cb..e03dcb2946 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -17,7 +17,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); - const int num_tensors = input_row_list.size(); + const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; std::vector> input_shape_list, output_shape_list; diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index 19d8a75a64..9ac6292e53 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -323,7 +323,8 @@ std::pair Float8BlockQuantizer::create_tensor( "Expected 1 or 2. Got ", block_scaling_dim); } - scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts); + scale_inv_rowwise = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, std::vector{sinv0, sinv1}); @@ -359,7 +360,8 @@ std::pair Float8BlockQuantizer::create_tensor( block_scaling_dim); } data_colwise = at::empty(torch_columnwise_shape, opts); - scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts); + scale_inv_colwise = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 37fbddcc18..a873586032 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -21,7 +21,7 @@ std::vector fused_multi_quantize(std::vector input_list, auto none = py::none(); // create TE tensors from input - for (int i = 0; i < input_list.size(); i++) { + for (size_t i = 0; i < input_list.size(); i++) { auto input_tensor = makeTransformerEngineTensor(input_list[i], none); const NVTEShape input_shape = input_tensor.shape(); From c84d170802dea33ab0bf95aded0611e8abf9cc46 Mon Sep 17 00:00:00 2001 From: Jianbin Chang Date: Mon, 7 Apr 2025 21:37:46 +0800 Subject: [PATCH 27/29] Support FP8 primary weight in FSDP training (#1630) Support fp8 primary weight in fsdp training Signed-off-by: jianbinc Co-authored-by: Kirthi Shankar Sivamani --- .../run_cast_master_weights_to_fp8.py | 280 +++++++++++++++++- transformer_engine/pytorch/tensor/utils.py | 53 +++- 2 files changed, 318 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py index 939684f152..ec06bb7e48 100644 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -21,7 +21,11 @@ ) import transformer_engine.pytorch as te from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.utils import replace_raw_data def _get_raw_data(quantized_tensor): @@ -228,6 +232,279 @@ def step(self): weight.data.copy_(master_weight) +class MiniFSDP: + def __init__(self, weights, lr, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + # Flatten the weights and pad to align with world size + raw_data_list = [ + _get_raw_data(w).view(-1) if isinstance(w, Float8Tensor) else w.view(-1) + for w in weights + ] + if isinstance(weights[0], Float8Tensor): + raw_data_list = [_get_raw_data(w).view(-1) for w in weights] + else: + raw_data_list = [w.view(-1) for w in weights] + self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list) + + # Split flattened weights into shards + self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] + self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard) + shard_size = self.flatten_weight.size(0) // world_size + + # Map original tensors to flattened indices + tensor_indices = [] + cumulative_length = 0 + for tensor in raw_data_list: + length = tensor.size(0) + tensor_indices.append((cumulative_length, cumulative_length + length)) + cumulative_length += length + + # Build shard index mappings + self.weight_indices = [] + self.shard_indices = [] + for idx, (start, end) in enumerate(tensor_indices): + shard_start = rank * shard_size + shard_end = shard_start + shard_size + adjusted_end = min(shard_end, original_length) + + if start <= adjusted_end and end >= shard_start: + start_idx = max(start, shard_start) + end_idx = min(end, adjusted_end) + self.weight_indices.append((start_idx - start, end_idx - start)) + self.shard_indices.append((start_idx - shard_start, end_idx - shard_start)) + else: + self.weight_indices.append((None, None)) + self.shard_indices.append((None, None)) + + if isinstance(weights[idx], Float8Tensor): + replace_raw_data( + weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) + ) + else: + weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape) + + # Initialize local model weights and high-precision master weights + self.local_weights = [] + self.master_weights = [] + for i, weight in enumerate(self.weights): + weight_start, weight_end = self.weight_indices[i] + shard_start, shard_end = self.shard_indices[i] + if shard_start is not None and shard_end is not None: + local_weight_shard = self.local_weight_shard[shard_start:shard_end] + self.local_weights.append(local_weight_shard) + + if isinstance(weight, QuantizedTensor): + high_precision_init_val = weight.get_high_precision_init_val().view(-1) + master_weight_shard = high_precision_init_val.to(weight.device).float()[ + weight_start:weight_end + ] + else: + master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end] + self.master_weights.append(master_weight_shard) + else: + self.local_weights.append(None) + self.master_weights.append(None) + setattr( + weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda") + ) + + def _flatten_tensors_with_pad(self, tensors): + """ + Flatten the list of tensors and pad them to align with the world size. + + Args: + tensors (list): List of tensors to flatten. + + Returns: + tuple: Flattened tensor and its original length before padding. + """ + world_size = dist.get_world_size(self.dp_group) + + flatten_tensor = torch.cat(tensors) + original_length = flatten_tensor.size(0) + + padding_needed = (world_size - original_length % world_size) % world_size + if padding_needed > 0: + flatten_tensor = torch.cat( + [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)] + ) + + return flatten_tensor, original_length + + def zero_grad(self): + for weight in self.weights: + weight.grad = None + weight.main_grad.zero_() + + def step(self): + """ + Perform an optimization step for the distributed sharded model. + + This method includes: + 1. Gradient reduce-scatter: Synchronize gradients across all processes. + 2. Master weight update: Update high-precision master weights using local gradients. + 3. Precision casting: Cast updated master weights to FP8 or BF16 precision. + 4. Weight synchronization: All-gather updated weights across all processes. + + Returns: + None + """ + # Step 1: Reduce-scatter the gradients + main_grad_buffer, _ = self._flatten_tensors_with_pad( + [weight.main_grad.view(-1) for weight in self.weights] + ) + main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype) + dist.reduce_scatter_tensor( + self.local_main_grad_shard, main_grad_buffer, group=self.dp_group + ) + + # Step 2: Update the master weights + for weight, master_weight, (shard_start, shard_end) in zip( + self.weights, self.master_weights, self.shard_indices + ): + if master_weight is None: + continue + + # Extract the local gradient shard for this weight + grad = self.local_main_grad_shard[shard_start:shard_end] + + # Update the master weight using gradient descent + master_weight -= grad * self.lr + + # Step 3: Cast master weights to FP8 or BF16 precision + if isinstance(self.weights[0], Float8Tensor): + local_weights = [] + for model_weight, local_weight in zip(self.weights, self.local_weights): + if local_weight is None: + local_weights.append(None) + continue + + quantizer = model_weight._get_quantizer() + if isinstance(quantizer, Float8CurrentScalingQuantizer): + local_weight = quantizer.create_tensor_from_data( + local_weight.view(-1), + model_weight.dtype, + ) + local_weights.append(local_weight) + + cast_master_weights_to_fp8( + self.weights, + self.master_weights, + [idx[0] for idx in self.weight_indices], + self.dp_group, + local_weights, + ) + else: + for weight, master_weight in zip(self.local_weights, self.master_weights): + if master_weight is None: + continue + + # Copy updated master weights to local weights + weight.data.copy_(master_weight) + + # Step 4: All-gather updated weights across processes + dist.all_gather_into_tensor( + self.flatten_weight, self.local_weight_shard, group=self.dp_group + ) + + +def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + # Configuration constants + NUM_STEPS = 100 + SEED = 12345 + + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = { + "params_dtype": torch.bfloat16, + "bias": False, + "fuse_wgrad_accumulation": False, + } + + # Create model with FP8 weights + with te.fp8.fp8_model_init( + enabled=quantization is not None, + recipe=quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + te.Linear(128, 256, **linear_kwargs), + te.Linear(256, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group) + optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + optimizer_fp8.zero_grad() + optimizer.zero_grad() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.fp8.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.fp8_autocast( + enabled=quantization is not None, + fp8_recipe=quantization_recipe(quantization), + fp8_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. + target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + print( + f"✅ Successfully validated FSDP {NUM_STEPS} training steps with" + f" {quantization} quantization" + ) + + def _test_zero_1(dp_group): """Make sure the implementation of zero-1 optimizer is correct""" rank = dist.get_rank(dp_group) @@ -389,6 +666,7 @@ def main(argv=None, namespace=None): dp_group = dist.new_group(backend="nccl") _test_zero_1(dp_group) _test_cast_master_weights_to_fp8(args.quantization, dp_group) + _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) dist.destroy_process_group() return 0 diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 34992f08bc..33c0953d94 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -38,7 +38,9 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet") -def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, group): +def cast_master_weights_to_fp8( + model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None +): r"""Helper function to cast master weights to FP8 primary weights. This is intended for use with ZeRO/FSDP. Each rank has a shard of @@ -55,14 +57,23 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro should be updated. group : The distributed group to do amax reduction. Typically it's the data parallel group. + fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are + not sharded. Otherwise, it means that the model weights are sharded and we get + target model weights data storage using the FSDP shard model weights. """ delayed_scaling_params = [] current_scaling_params = [] - for model_weight, master_weight, start_offset in zip( - model_weights, master_weights, start_offsets + if fsdp_shard_model_weights is None: + use_fsdp_shard_model_weights = False + fsdp_shard_model_weights = [None] * len(model_weights) + else: + use_fsdp_shard_model_weights = True + + for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( + model_weights, master_weights, start_offsets, fsdp_shard_model_weights ): # Clear `_high_precision_init_val` of model_weight automatically. # - Master weights are initialized from model weights, if we use fp8 primary weights to @@ -88,9 +99,13 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro quantizer = model_weight._get_quantizer() if isinstance(quantizer, Float8Quantizer): - delayed_scaling_params.append((model_weight, master_weight, start_offset)) + delayed_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) elif isinstance(quantizer, Float8CurrentScalingQuantizer): - current_scaling_params.append((model_weight, master_weight, start_offset)) + current_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) elif isinstance(quantizer, MXFP8Quantizer): raise NotImplementedError( "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet" @@ -101,12 +116,16 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro ) if len(delayed_scaling_params) > 0: - _cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, group) + _cast_master_weights_to_fp8_delayed_scaling( + delayed_scaling_params, group, use_fsdp_shard_model_weights + ) if len(current_scaling_params) > 0: - _cast_master_weights_to_fp8_current_scaling(current_scaling_params, group) + _cast_master_weights_to_fp8_current_scaling( + current_scaling_params, group, use_fsdp_shard_model_weights + ) -def _cast_master_weights_to_fp8_delayed_scaling(params, group): +def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False): r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. Parameters @@ -115,13 +134,14 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): indicating the starting index of the master weight in the model weight. group : The distributed group to do amax reduction. Typically it's the data parallel group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. """ # Collect amaxes to do reduce-max among dp group. # Collect scales and scale_invs to update scale_invs of the fp8 weights. amaxes, scales, scale_invs = [], [], [] - for model_weight, master_weight, start_offset in params: + for model_weight, master_weight, start_offset, shard_model_weight_raw in params: # Reset transpose cache for all model weights. # We cannot create transpose cache here because users (like megatron) may want to overlap # the all-gather of model weights and forward process, so the model weight is not updated @@ -147,7 +167,8 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): # master_weight may be smaller than model_weight because it could be distributed across # multiple ranks. So we need to create a dummy weight using the raw data from model_weight. - shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset] + if not use_fsdp_shard_model_weights: + shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset] shard_model_weight_fp8 = quantizer.create_tensor_from_data( shard_model_weight_raw.view(1, -1), model_weight.dtype, @@ -186,7 +207,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): ) -def _cast_master_weights_to_fp8_current_scaling(params, group): +def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False): r"""Helper function to cast master weights to FP8 primary weights for current scaling. Parameters @@ -195,6 +216,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): indicating the starting index of the master weight in the model weight. group : The distributed group to do amax reduction. Typically it's the data parallel group. + use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded. """ # Parameter attributes @@ -219,7 +241,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): # amaxes in a contiguous buffer. If the master weight is None, the corresponding amax # will be set to 0. # --------------------------------------------------------------------------------------------- - for (model_weight, master_weight, _), amax in zip(params, amaxes): + for (model_weight, master_weight, _, _), amax in zip(params, amaxes): # Make sure all the model weights have the same numerical options. quantizer = model_weight._get_quantizer() @@ -260,7 +282,9 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): # --------------------------------------------------------------------------------------------- # Step 4: Cast master weights to FP8. # --------------------------------------------------------------------------------------------- - for (model_weight, master_weight, start_offset), scale in zip(params, scales): + for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( + params, scales + ): # Reset transpose cache for all model weights. # We cannot create transpose cache here because users (like megatron) may want to overlap # the all-gather of model weights and forward process, so the model weight is not updated @@ -274,7 +298,8 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): # Cast master weight to FP8 end_offset = start_offset + master_weight.numel() - model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset] + if not use_fsdp_shard_model_weights: + model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset] quantizer = Float8Quantizer( scale=scale, amax=torch.Tensor(), From b362a6e071a8867d5d376351a43beb76d547e4e1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 7 Apr 2025 13:30:33 -0400 Subject: [PATCH 28/29] Removing NVTE_NO_SCALING (#1650) * rm no scaling enum Signed-off-by: Phuong Nguyen * update jax enum Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../common/include/transformer_engine/transformer_engine.h | 3 +-- transformer_engine/jax/cpp_extensions/gemm.py | 5 +++++ transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- transformer_engine/jax/quantize/scaling_modes.py | 4 ++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index c539265e62..ba47b9d38c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -86,8 +86,7 @@ enum NVTEScalingMode { */ NVTE_BLOCK_SCALING_1D = 2, NVTE_BLOCK_SCALING_2D = 3, - NVTE_INVALID_SCALING = 4, - NVTE_NO_SCALING = 5 + NVTE_INVALID_SCALING = 100 }; /*! \brief TE Tensor type diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 736105dd75..1df2bcc97f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -491,6 +491,11 @@ def grouped_gemm( bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_) dim_list = jnp.array(dims, dtype=jnp.int32) + # TE/common does not support NVTE_NO_SCALING yet + # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16 + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + scaling_mode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING + # Perform batched GEMM on flattened inputs out_contig = GroupedGemmPrimitive.outer_primitive.bind( lhs_contig, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index c1e008a5bc..e5ec160c91 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -90,7 +90,7 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh auto lhs_sinv_shape = std::vector{1, 1}; auto rhs_sinv_shape = std::vector{1, 1}; - if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { auto lhs_i = TensorWrapper(static_cast(lhs_ptr), lhs_shape, lhs_dtype, nullptr, nullptr, reinterpret_cast(lhs_sinv_ptr)); auto rhs_i = TensorWrapper(static_cast(rhs_ptr), rhs_shape, rhs_dtype, nullptr, diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index a9c93a3553..95bbc9bb41 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -233,8 +233,8 @@ class ScalingMode(Enum): NVTE_DELAYED_TENSOR_SCALING = 0 NVTE_MXFP8_1D_SCALING = 1 - NVTE_INVALID_SCALING = 4 - NVTE_NO_SCALING = 5 + NVTE_INVALID_SCALING = 100 + NVTE_NO_SCALING = 1000 def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. From f3123cf37f876588d385fb86ebc96375fbb6a4a4 Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Mon, 7 Apr 2025 13:19:38 -0700 Subject: [PATCH 29/29] Prune number of tests. Signed-off-by: Keith Wyss --- .../test_float8_blockwise_gemm_exact.py | 82 ++++++++++++++----- 1 file changed, 60 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index 9ddb4b9989..9a1cfa2db8 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -326,25 +326,68 @@ def cublas_gemm_test_constraint_enforced( (128, 128, 128), (256, 128, 256), # non 128x128 divisible input shapes - (16, 128, 128), - (16, 64, 128), - (128, 160, 128), (320, 128, 336), (320, 64, 336), # k > 128 (256, 256, 256), (320, 256, 336), - (256, 512, 256), - (256, 1024, 256), - (1024, 1024, 1024), (1024, 4096, 1024), - (512, 128, 512), - (768, 128, 768), - (1024, 128, 1024), - (1536, 128, 1536), - (2048, 128, 2048), - (4096, 128, 4096), - (4096, 512, 3072), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("noise_type", ["normal"], ids=str) +@pytest.mark.parametrize("x_magnitude", [1], ids=str) +@pytest.mark.parametrize("w_magnitude", [1], ids=str) +@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"]) +@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) +@pytest.mark.parametrize( + "is_x_1d_scaled, is_w_1d_scaled", + [ + (True, False), + (True, True), + (False, True), + ], + ids=["1Dx2D", "1Dx1D", "2Dx1D"], +) +def test_cublas_gemm_fp8_blockwise_shape_varying( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, +): + cublas_gemm_fp8_blockwise_case( + x_dtype, + w_dtype, + out_dtype, + M, + K, + N, + noise_type, + x_magnitude, + w_magnitude, + accumulate, + use_split_accumulator, + is_x_1d_scaled, + is_w_1d_scaled, + ) + + +@pytest.mark.parametrize( + "M, K, N", + [ + (256, 128, 256), + (320, 256, 336), ], ) @pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @@ -364,7 +407,7 @@ def cublas_gemm_test_constraint_enforced( ], ids=["1Dx2D", "1Dx1D", "2Dx1D"], ) -def test_cublas_gemm_fp8_blockwise_shape_varying( +def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying( x_dtype, w_dtype, out_dtype, @@ -402,18 +445,16 @@ def test_cublas_gemm_fp8_blockwise_shape_varying( # k = 128 (256, 128, 256), # non 128x128 divisible input shapes - (16, 128, 128), (320, 64, 336), # k > 128 (256, 256, 256), - (4096, 128, 4096), ], ) @pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize("noise_type", ["normal"], ids=str) -@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str) +@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str) @pytest.mark.parametrize("w_magnitude", [1], ids=str) @pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) @pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"]) @@ -468,7 +509,6 @@ def test_cublas_gemm_fp8_blockwise_bias( (16, 128, 128), (320, 64, 336), # k > 128 - (256, 256, 256), (4096, 128, 4096), ], ) @@ -540,15 +580,13 @@ def test_cublas_gemm_fp8_blockwise_columnwise( # k = 128 (256, 128, 256), # non 128x128 divisible input shapes - (16, 128, 128), (320, 64, 336), # k > 128 (256, 256, 256), - (4096, 128, 4096), ], ) -@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) -@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("noise_type", ["normal"], ids=str) @pytest.mark.parametrize("x_magnitude", [1], ids=str)