From 19b6b08365005cbb8501b519467450cc63003dbf Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 8 May 2026 20:49:30 -0700 Subject: [PATCH 01/57] Initial implementation Signed-off-by: Ziang Li --- docs/envvars.rst | 6 + tests/cpp/operator/test_dequantize_nvfp4.cu | 38 ++- tests/cpp/test_common.cu | 6 + tests/cpp/test_common.h | 1 + .../nvfp4/test_nvfp4_quantize_exact.py | 35 ++- tests/pytorch/test_backward_override.py | 26 +- tests/pytorch/test_cuda_graphs.py | 12 +- tests/pytorch/test_recipe.py | 87 +++++- tests/pytorch/test_sanity.py | 12 +- tests/pytorch/utils.py | 13 +- .../common/cast/dispatch/quantize.cuh | 18 +- .../common/cast/nvfp4/core_nvfp4.cuh | 147 ++++++++- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 10 +- .../quantize_transpose_nvfp4_tuned_1D.cuh | 264 +++++++++++----- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 4 + transformer_engine/common/common.h | 13 +- .../transformer_engine/transformer_engine.h | 21 ++ transformer_engine/common/recipe/__init__.py | 13 + .../common/transformer_engine.cpp | 12 + .../common/transpose/cast_transpose.h | 4 +- ...quantize_transpose_vector_blockwise_fp4.cu | 288 +++++++++++++----- transformer_engine/pytorch/csrc/common.h | 1 + .../pytorch/csrc/extensions/cast.cpp | 32 +- transformer_engine/pytorch/csrc/quantizer.cpp | 18 ++ .../pytorch/csrc/type_converters.cpp | 2 + .../custom_recipes/quantization_ref_nvfp4.py | 142 +++++++-- transformer_engine/pytorch/quantization.py | 1 + .../pytorch/tensor/grouped_tensor.py | 5 +- .../pytorch/tensor/nvfp4_tensor.py | 7 + .../tensor/storage/grouped_tensor_storage.py | 44 ++- .../tensor/storage/nvfp4_tensor_storage.py | 8 + 31 files changed, 1056 insertions(+), 234 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index ffbad409d4..37b8028533 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -287,6 +287,12 @@ Kernel Configuration :Default: ``0`` :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar. +.. envvar:: NVTE_NVFP4_ENABLE_4OVER6 + + :Type: ``int`` (0 or 1) + :Default: ``0`` + :Description: Enable 4over6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled, either with the corresponding recipe fields or with :envvar:`NVTE_NVFP4_DISABLE_RHT`, :envvar:`NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING`, and :envvar:`NVTE_NVFP4_DISABLE_2D_QUANTIZATION`. + Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index eb9e8bce23..a0c53b8f86 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -46,8 +46,9 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, OType *output, size_t rows, size_t cols, - size_t scale_stride) { - constexpr float factor_inv = 1.0f / (6.0f * 448.0f); + size_t scale_stride, + bool use_4over6) { + const float factor_inv = 1.0f / (6.0f * (use_4over6 ? 256.0f : 448.0f)); constexpr size_t BLOCK_SIZE = 16; const size_t Mread = cols / BLOCK_SIZE; const size_t bytes_per_block = BLOCK_SIZE / 2; @@ -90,7 +91,8 @@ float compute_amax(test::Tensor &t, size_t rows, size_t cols) { // against a CPU reference computed from the quantized data. template void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, - const bool row_scaled_nvfp4) { + const bool row_scaled_nvfp4, + const bool use_4over6) { using namespace test; DType otype = TypeInfo::dtype; @@ -105,6 +107,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Configure quantized tensor amax size_t amax_size = 1; + quantized.set_nvfp4_4over6(use_4over6); if (row_scaled_nvfp4) { quantized.set_row_scaled_nvfp4(true); amax_size = rows; @@ -116,7 +119,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Quantize if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized.data(), 0); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_4over6(use_4over6); + nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); @@ -146,7 +151,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, std::make_unique(rows * cols); compute_ref_dequantize_nvfp4( fp4_data, scales, amax_vals, ref_output.get(), - rows, cols, scale_stride); + rows, cols, scale_stride, use_4over6); // Compare results from TE and reference impls auto [atol, rtol] = getTolerances(otype); @@ -156,7 +161,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. template void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, - const bool row_scaled_nvfp4) { + const bool row_scaled_nvfp4, + const bool use_4over6) { using namespace test; DType otype = TypeInfo::dtype; @@ -165,6 +171,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); + quantized_compact.set_nvfp4_4over6(use_4over6); if (row_scaled_nvfp4) { quantized_compact.set_row_scaled_nvfp4(true); } else if (rows > 0 && cols > 0) { @@ -174,7 +181,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, } if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized_compact.data(), 0); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_4over6(use_4over6); + nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); cudaDeviceSynchronize(); } @@ -186,6 +195,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); + quantized_swizzled.set_nvfp4_4over6(use_4over6); if (row_scaled_nvfp4) { quantized_swizzled.set_row_scaled_nvfp4(true); } else { @@ -260,6 +270,7 @@ std::vector> nvfp4_tensor_dims = { class DequantizeNVFP4TestSuite : public ::testing::TestWithParam , transformer_engine::DType, + bool, bool>> {}; TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) @@ -271,10 +282,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); + const bool use_4over6 = std::get<3>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second, row_scaled_nvfp4); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6); ); } @@ -284,13 +296,15 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool(), ::testing::Bool()), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + - (std::get<2>(info.param) ? "RowScaled" : "PerTensor"); + (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + + (std::get<3>(info.param) ? "FourOverSix" : "Default"); return name; } ); @@ -298,6 +312,7 @@ INSTANTIATE_TEST_SUITE_P( class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam , transformer_engine::DType, + bool, bool>> {}; TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) @@ -309,10 +324,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); + const bool use_4over6 = std::get<3>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second, row_scaled_nvfp4); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6); ); } @@ -322,6 +338,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool(), ::testing::Bool()), [](const testing::TestParamInfo& info) { @@ -329,6 +346,7 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + + (std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" + "Swizzled"; return name; } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4fd75bb927..6474540b39 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -440,6 +440,12 @@ void Tensor::set_row_scaled_nvfp4(bool row_scaled_nvfp4) { } } +void Tensor::set_nvfp4_4over6(bool nvfp4_4over6) { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "NVFP4 4over6 is only supported for NVFP4 tensors."); + tensor_.set_nvfp4_4over6(nvfp4_4over6); +} + void Tensor::to_cpu() { if (data_rowwise_) { data_rowwise_->to_cpu(); } if (data_columnwise_) { data_columnwise_->to_cpu(); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 17f36a99dd..06afb86e7c 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -297,6 +297,7 @@ class Tensor { void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales); void set_row_scaled_nvfp4(bool row_scaled_nvfp4); + void set_nvfp4_4over6(bool nvfp4_4over6); void to_cpu(); void from_cpu(); diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 53569d90d9..d7f9c8994e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -20,7 +20,10 @@ def maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4: bool, return_transpose: bool, with_2d_quantization: bool = False, + use_4over6: bool = False, ) -> None: + if use_4over6 and with_2d_quantization: + pytest.skip("NVFP4 4over6 does not support 2D quantization") if not row_scaled_nvfp4: return if return_transpose: @@ -45,9 +48,10 @@ def check_quantization_nvfp4_versus_reference( use_cpp_allocator: bool, with_2d_quantization: bool, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, ) -> None: maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, with_2d_quantization + row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6 ) te_dtype = tex.DType.kFloat4E2M1 @@ -71,6 +75,7 @@ def check_quantization_nvfp4_versus_reference( with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -104,6 +109,7 @@ def check_quantization_nvfp4_versus_reference( eps=0.0, quant_tile_shape=quant_tile_shape, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -179,6 +185,7 @@ def check_quantization_nvfp4_versus_reference( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -188,6 +195,7 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator: bool, with_2d_quantization: bool, row_scaled_nvfp4: bool, + use_4over6: bool, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -198,6 +206,7 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) @@ -215,6 +224,7 @@ def test_quantization_block_tiling_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -223,8 +233,11 @@ def test_nvfp4_quantization_extrema_versus_reference( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, + use_4over6: bool, ): - maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + ) te_dtype = tex.DType.kFloat4E2M1 @@ -247,6 +260,7 @@ def test_nvfp4_quantization_extrema_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) if use_cpp_allocator: @@ -278,6 +292,7 @@ def test_nvfp4_quantization_extrema_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -322,6 +337,7 @@ def test_nvfp4_quantization_extrema_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -329,13 +345,16 @@ def test_nvfp4_quantization_boundary_values( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, + use_4over6: bool, ): """ Stress rounding/threshold behavior by placing values just below/above many potential bin edges within each 16-element microblock. Validates native vs reference byte-for-byte and scale parity. """ - maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + ) te_dtype = tex.DType.kFloat4E2M1 @@ -367,6 +386,7 @@ def test_nvfp4_quantization_boundary_values( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) if use_cpp_allocator: @@ -398,6 +418,7 @@ def test_nvfp4_quantization_boundary_values( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -442,6 +463,7 @@ def test_nvfp4_quantization_boundary_values( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, @@ -449,8 +471,11 @@ def test_nvfp4_quantization_noncontiguous_inputs( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, + use_4over6: bool, ): - maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + ) te_dtype = tex.DType.kFloat4E2M1 @@ -473,6 +498,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) if use_cpp_allocator: @@ -504,6 +530,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 43e9587d95..832b5a3741 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -39,6 +39,8 @@ # -------------------------- _BACKWARD_OVERRIDES = ("high_precision", "dequantized") +_NVFP4_RECIPE_NAMES = ("nvfp4", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6") +_NVFP4_ROW_SCALED_RECIPE_NAMES = ("nvfp4_row_scaled", "nvfp4_row_scaled_4over6") fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) @@ -79,9 +81,9 @@ id="NVFP4BlockScaling", ), pytest.param( - "nvfp4_row_scaled", + "nvfp4_row_scaled_4over6", marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), - id="NVFP4RowScaledBlockScaling", + id="NVFP4RowScaled4Over6BlockScaling", ), ] @@ -170,7 +172,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name in ("nvfp4", "nvfp4_row_scaled"): + if recipe_name in _NVFP4_RECIPE_NAMES: if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -183,12 +185,12 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - if module_type == "ops_linear" and recipe_name == "nvfp4_row_scaled": + if module_type == "ops_linear" and recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES: pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: - if recipe_name == "nvfp4_row_scaled": + if recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES: return make_recipe(recipe_name, backward_override="dequantized") return make_recipe(recipe_name) @@ -208,9 +210,7 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( - flat_first_dim % 16 != 0 or last_dim % 16 != 0 - ): + if recipe_name in _NVFP4_RECIPE_NAMES and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -235,9 +235,7 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( - flat_first_dim % 16 != 0 or last_dim % 16 != 0 - ): + if recipe_name in _NVFP4_RECIPE_NAMES and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -256,9 +254,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in _NVFP4_RECIPE_NAMES and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in _NVFP4_RECIPE_NAMES and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." @@ -1741,7 +1739,7 @@ def test_backward_override_memory_peak_report( modes = ( ("high_precision", "dequantized") - if recipe_name == "nvfp4_row_scaled" + if recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES else (None, "high_precision", "dequantized") ) mode_results: dict[str, dict[str, float] | str] = {} diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 33ba65e0d9..7ed8ebdb22 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -64,8 +64,14 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe -def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) +def nvfp4_row_scaled_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + enable_4over6=True, + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() @@ -100,7 +106,7 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) - fp8_recipes.append(nvfp4_row_scaled()) + fp8_recipes.append(nvfp4_row_scaled_4over6()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 5f5221af76..90e7381d44 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -28,6 +28,7 @@ NVFP4BlockScalingRecipeState, _amax_and_scale_update, ) +from transformer_engine.pytorch.tensor.storage.grouped_tensor_storage import GroupedTensorStorage import transformer_engine.pytorch.ops as te_ops from transformer_engine.common.recipe import ( DelayedScaling, @@ -534,9 +535,77 @@ def test_nvfp4_row_scaled_quantizer_roles(): assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +def test_nvfp4_4over6_quantizer_roles(): + recipe = NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + enable_4over6=True, + row_scaled_activation=True, + ) + + forward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="forward", + num_quantizers=3, + ).make_quantizers() + assert [q.use_4over6 for q in forward_quantizers] == [True, True, True] + assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] + + backward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="backward", + num_quantizers=2, + ).make_quantizers() + assert [q.use_4over6 for q in backward_quantizers] == [True, True] + assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] + + +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +def test_nvfp4_grouped_storage_metadata(): + q = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=False, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + row_scaled_nvfp4=True, + use_4over6=True, + ) + + grouped_tensor = GroupedTensorStorage.make_grouped_tensor( + num_tensors=2, + first_dims=None, + last_dims=None, + logical_first_dim=32, + logical_last_dim=64, + quantizer=q, + device=torch.device("cuda"), + dtype=torch.bfloat16, + ) + assert grouped_tensor._row_scaled_nvfp4 + assert grouped_tensor.row_scaled_nvfp4 + assert grouped_tensor._use_4over6 + assert grouped_tensor.use_4over6 + + grouped_copy = grouped_tensor.copy() + assert grouped_copy._row_scaled_nvfp4 + assert grouped_copy.row_scaled_nvfp4 + assert grouped_copy._use_4over6 + assert grouped_copy.use_4over6 + + for tensor in grouped_tensor.split_into_quantized_tensors(): + assert tensor._row_scaled_nvfp4 + assert tensor._use_4over6 + + @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( "M, N", [ @@ -552,24 +621,28 @@ def test_nvfp4_row_scaled_quantizer_roles(): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, row_scaled_nvfp4, M, N): +def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N): q = NVFP4Quantizer( columnwise=not row_scaled_nvfp4, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) assert starting_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert starting_tensor._use_4over6 == use_4over6 assert starting_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert new_tensor._use_4over6 == use_4over6 assert new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) - torch.testing.assert_close( - new_tensor._rowwise_data, - starting_tensor._rowwise_data, - rtol=0, - atol=0, - ) + if not use_4over6: + torch.testing.assert_close( + new_tensor._rowwise_data, + starting_tensor._rowwise_data, + rtol=0, + atol=0, + ) new_dequantized_tensor = new_tensor.dequantize() torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c811342df5..67fefc8217 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -94,8 +94,14 @@ def nvfp4_vanilla(): return nvfp4_recipe -def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) +def nvfp4_row_scaled_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + enable_4over6=True, + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() @@ -115,7 +121,7 @@ def nvfp4_row_scaled(): fp8_recipes.append(None) fp8_recipes_with_row_scaled = fp8_recipes.copy() if nvfp4_available: - fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) + fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled_4over6()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 2ee18aaf57..62742d4ebb 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_row_scaled"): + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -160,6 +160,15 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: row_scaled_activation=True, **recipe_kwargs, ) + if name == "nvfp4_row_scaled_4over6": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + enable_4over6=True, + **recipe_kwargs, + ) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -167,6 +176,8 @@ def recipe_id(recipe: Optional[Recipe]) -> str: """Readable pytest id for a quantization recipe.""" if not isinstance(recipe, Recipe): return "None" + if recipe.nvfp4() and recipe.row_scaled_activation and recipe.enable_4over6: + return "NVFP4RowScaled4Over6BlockScaling" if recipe.nvfp4() and recipe.row_scaled_activation: return "NVFP4RowScaledBlockScaling" return type(recipe).__name__ diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 123362ce10..9fccffe993 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -101,6 +101,11 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; + const bool use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(!use_4over6 || !quant_config_cpp.nvfp4_2d_quantization, + "NVFP4 4over6 quantization does not support 2D quantization."); + NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); if (row_scaled_nvfp4) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -132,9 +137,11 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, /*swizzled_scale=*/false, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*use_fast_math=*/quant_config_cpp.use_fast_math, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/row_scaled_nvfp4, + /*use_4over6=*/use_4over6, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } @@ -249,6 +256,11 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); + const bool use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(!use_4over6 || !quant_config_cpp.nvfp4_2d_quantization, + "NVFP4 4over6 quantization does not support 2D quantization."); + NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && @@ -275,9 +287,11 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, /*swizzled_scale=*/false, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*use_fast_math=*/quant_config_cpp.use_fast_math, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*row_scaled_nvfp4=*/false, /*noop_tensor=*/noop_tensor->data, + /*row_scaled_nvfp4=*/false, + /*use_4over6=*/use_4over6, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } break; @@ -374,6 +388,8 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "2D quantization is not supported for group quantize."); + NVTE_CHECK(!quant_config_cpp.nvfp4_4over6, + "NVFP4 4over6 quantization is not supported for group quantize."); // Launch NVFP4 group quantize kernel nvfp4::group_quantize_transpose( diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index 792b068cbc..fcd88d9585 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -76,10 +77,11 @@ namespace core { using namespace ptx; // Compute the global encode scale factor for a given global amax +template __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { using namespace detail; - constexpr float fp8_max = TypeExtrema::max; // 448.0f; - constexpr float fp4_max = TypeExtrema::max; // 6.0f; + constexpr float fp8_max = USE_4OVER6 ? 256.0f : TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); @@ -90,6 +92,147 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const return global_encode_scale; } +__device__ __forceinline__ void compute_4over6_decoding_scaling_factors( + const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, + nvfp4_scale_t &S_dec_b_fp8_map6) { + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + const float sf_high_precision = block_amax / fp4_max * S_enc; + S_dec_b_fp8_map4 = static_cast(sf_high_precision * 1.5f); + S_dec_b_fp8_map6 = static_cast(sf_high_precision); +} + +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float (&x)[8], + const float block_scale_inverse, + const nvfp4_scale_t S_dec_b_fp8, + const float global_amax, + float *err) { + uint32_t out = 0; + uint32_t out_dequant_1 = 0; + uint32_t out_dequant_2 = 0; + uint32_t out_dequant_3 = 0; + uint32_t out_dequant_4 = 0; + + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + float x_scaled[8]; + if constexpr (USE_FAST_MATH) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + x_scaled[i] = x[i] * block_scale_inverse; + } + } else { + x_scaled[0] = __fmul_rn(x[0], block_scale_inverse); + x_scaled[1] = __fmul_rn(x[1], block_scale_inverse); + x_scaled[2] = __fmul_rn(x[2], block_scale_inverse); + x_scaled[3] = __fmul_rn(x[3], block_scale_inverse); + x_scaled[4] = __fmul_rn(x[4], block_scale_inverse); + x_scaled[5] = __fmul_rn(x[5], block_scale_inverse); + x_scaled[6] = __fmul_rn(x[6], block_scale_inverse); + x_scaled[7] = __fmul_rn(x[7], block_scale_inverse); + } + + asm volatile( + "{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %8, %7;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %10, %9;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %12, %11;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "cvt.rn.f16x2.e2m1x2 %1, byte0;\n" + "cvt.rn.f16x2.e2m1x2 %2, byte1;\n" + "cvt.rn.f16x2.e2m1x2 %3, byte2;\n" + "cvt.rn.f16x2.e2m1x2 %4, byte3;\n" + "}" + : "=r"(out), "=r"(out_dequant_1), "=r"(out_dequant_2), "=r"(out_dequant_3), + "=r"(out_dequant_4) + : "f"(x_scaled[0]), "f"(x_scaled[1]), "f"(x_scaled[2]), "f"(x_scaled[3]), "f"(x_scaled[4]), + "f"(x_scaled[5]), "f"(x_scaled[6]), "f"(x_scaled[7])); + + const uint16_t out_dequant_1_hi = (out_dequant_1 >> 16) & 0xFFFF; + const uint16_t out_dequant_1_lo = out_dequant_1 & 0xFFFF; + const uint16_t out_dequant_2_hi = (out_dequant_2 >> 16) & 0xFFFF; + const uint16_t out_dequant_2_lo = out_dequant_2 & 0xFFFF; + const uint16_t out_dequant_3_hi = (out_dequant_3 >> 16) & 0xFFFF; + const uint16_t out_dequant_3_lo = out_dequant_3 & 0xFFFF; + const uint16_t out_dequant_4_hi = (out_dequant_4 >> 16) & 0xFFFF; + const uint16_t out_dequant_4_lo = out_dequant_4 & 0xFFFF; + + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_4over6_max = 256.0f; + constexpr float mse_denom = fp4_max * fp8_4over6_max; + const float sf = static_cast(S_dec_b_fp8); + if constexpr (USE_FAST_MATH) { + const float dequant[8] = { + __half2float(__ushort_as_half(out_dequant_1_lo)), + __half2float(__ushort_as_half(out_dequant_1_hi)), + __half2float(__ushort_as_half(out_dequant_2_lo)), + __half2float(__ushort_as_half(out_dequant_2_hi)), + __half2float(__ushort_as_half(out_dequant_3_lo)), + __half2float(__ushort_as_half(out_dequant_3_hi)), + __half2float(__ushort_as_half(out_dequant_4_lo)), + __half2float(__ushort_as_half(out_dequant_4_hi)), + }; +#pragma unroll + for (int i = 0; i < 8; ++i) { + const float val = dequant[i] * sf * global_amax / mse_denom; + const float diff = val - x[i]; + *err += diff * diff; + } + } else { + const float val0 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_lo)), sf), global_amax), + mse_denom); + const float val1 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_hi)), sf), global_amax), + mse_denom); + const float val2 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_lo)), sf), global_amax), + mse_denom); + const float val3 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_hi)), sf), global_amax), + mse_denom); + const float val4 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_lo)), sf), global_amax), + mse_denom); + const float val5 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_hi)), sf), global_amax), + mse_denom); + const float val6 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_lo)), sf), global_amax), + mse_denom); + const float val7 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_hi)), sf), global_amax), + mse_denom); + + const float diff0 = __fsub_rn(val0, x[0]); + const float diff1 = __fsub_rn(val1, x[1]); + const float diff2 = __fsub_rn(val2, x[2]); + const float diff3 = __fsub_rn(val3, x[3]); + const float diff4 = __fsub_rn(val4, x[4]); + const float diff5 = __fsub_rn(val5, x[5]); + const float diff6 = __fsub_rn(val6, x[6]); + const float diff7 = __fsub_rn(val7, x[7]); + + *err = __fadd_rn(*err, __fmul_rn(diff0, diff0)); + *err = __fadd_rn(*err, __fmul_rn(diff1, diff1)); + *err = __fadd_rn(*err, __fmul_rn(diff2, diff2)); + *err = __fadd_rn(*err, __fmul_rn(diff3, diff3)); + *err = __fadd_rn(*err, __fmul_rn(diff4, diff4)); + *err = __fadd_rn(*err, __fmul_rn(diff5, diff5)); + *err = __fadd_rn(*err, __fmul_rn(diff6, diff6)); + *err = __fadd_rn(*err, __fmul_rn(diff7, diff7)); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + + return out; +} + __device__ __forceinline__ uint32_t get_rbits( transformer_engine::curanddx::detail::philox4x32_native_state &rng, diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index d549a050ee..9bc54fc191 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -35,8 +35,8 @@ template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, const float *const tensor_amax, const bool row_scaled_nvfp4, - const size_t N, const size_t M, const size_t scale_stride, - const size_t num_scale_tiles_X) { + const bool use_4over6, const size_t N, const size_t M, + const size_t scale_stride, const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -65,7 +65,8 @@ __global__ void __launch_bounds__(512) value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; float amax = row_scaled_nvfp4 ? tensor_amax[y] : tensor_amax[0]; - constexpr float factor_inv = 1.0 / (6.0 * 448.0); + const float fp8_max = use_4over6 ? 256.0f : 448.0f; + const float factor_inv = 1.0f / (6.0f * fp8_max); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { @@ -92,6 +93,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; + const bool use_4over6 = input.nvfp4_4over6; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -116,7 +118,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) dequantize_fp4_kernel<<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), row_scaled_nvfp4, N, Mread, + reinterpret_cast(input.amax.dptr), row_scaled_nvfp4, use_4over6, N, Mread, input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 8adda82131..e9a567825d 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -184,14 +184,12 @@ compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const f return static_cast(scale_rcp); } -template -__device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr, - fp4e2m1x2 *__restrict__ sOut_tr_ptr, - nvfp4_scale_t *__restrict__ sSFcolwise_ptr, - const float S_enc_colwise, const int stage_Y, - const int stage_X, const int buff_in, - const int buff_out_tr, RNG_t &rng, - uint4 &random_uint4, int &rnd_idx) { +template +__device__ __forceinline__ void colwise_scaling( + const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, + const float global_amax_colwise, const int stage_Y, const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; const auto &sIn2x = *reinterpret_cast(sIn_ptr); @@ -231,37 +229,84 @@ __device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_pt static_cast(__habs(thread_amax_2x.y))}; #pragma unroll for (int w = 0; w < 2; ++w) { - const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); - - // Store scaling factors to SMEM buffer (R2S) - sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + __align__(8) uint32_t rOut[SCALE_DIM / 8]; + nvfp4_scale_t S_dec_b_fp8; - const scaling_coeff_type SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); + if constexpr (USE_4OVER6) { + nvfp4_scale_t S_dec_b_fp8_map4; + nvfp4_scale_t S_dec_b_fp8_map6; + core::compute_4over6_decoding_scaling_factors(block_amax[w], S_enc_colwise, S_dec_b_fp8_map4, + S_dec_b_fp8_map6); + + const scaling_coeff_type SFcoefficient_map4 = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map4, S_enc_colwise); + const scaling_coeff_type SFcoefficient_map6 = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map6, S_enc_colwise); + + float err_map4 = 0.0f; + float err_map6 = 0.0f; + __align__(8) uint32_t rOut_map4[SCALE_DIM / 8]; + __align__(8) uint32_t rOut_map6[SCALE_DIM / 8]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const float x[8] = { + static_cast(rIn[w][8 * e + 0]), static_cast(rIn[w][8 * e + 1]), + static_cast(rIn[w][8 * e + 2]), static_cast(rIn[w][8 * e + 3]), + static_cast(rIn[w][8 * e + 4]), static_cast(rIn[w][8 * e + 5]), + static_cast(rIn[w][8 * e + 6]), static_cast(rIn[w][8 * e + 7]), + }; + rOut_map4[e] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax_colwise, + &err_map4); + rOut_map6[e] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax_colwise, + &err_map6); + } - // Scale elements - __align__(8) uint32_t rOut[SCALE_DIM / 8]; + if (err_map4 < err_map6) { + S_dec_b_fp8 = S_dec_b_fp8_map4; #pragma unroll - for (int e = 0; e < SCALE_DIM / 8; ++e) { - const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); - const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); - if constexpr (USE_STOCHASTIC_ROUNDING) { - const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); - const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); - rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( - elts03, elts47, SFcoefficient, rbits03, rbits47); + for (int e = 0; e < SCALE_DIM / 8; ++e) { + rOut[e] = rOut_map4[e]; + } } else { - rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, - SFcoefficient); + S_dec_b_fp8 = S_dec_b_fp8_map6; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + rOut[e] = rOut_map6[e]; + } + } + } else { + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } } } + + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + uint64_t &out_pack_16x = *reinterpret_cast(rOut); ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], out_pack_16x); } } -template +template __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, @@ -314,19 +359,100 @@ __device__ __forceinline__ void rowwise_scaling( const float block_amax = get_amax_of_pair(thread_amax_2x); nvfp4_scale_t S_dec_b_fp8; - scaling_coeff_type SFcoefficient; + float block_S_enc_rowwise; + float block_global_amax; if constexpr (ROW_SCALED_NVFP4) { const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; - const float S_enc_rowwise_block = - row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) - : 1.0f; - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); - SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); + if (row_idx < rows) { + block_global_amax = amax_rowwise_ptr[row_idx]; + block_S_enc_rowwise = + core::compute_global_encode_scaling_factor_FP4(block_global_amax); + } else { + block_global_amax = 1.0f; + block_S_enc_rowwise = 1.0f; + } + } else { + block_global_amax = *amax_rowwise_ptr; + block_S_enc_rowwise = S_enc_rowwise; + } + + __align__(8) uint32_t rOut[WAVES]; + + if constexpr (USE_4OVER6) { + nvfp4_scale_t S_dec_b_fp8_map4; + nvfp4_scale_t S_dec_b_fp8_map6; + core::compute_4over6_decoding_scaling_factors(block_amax, block_S_enc_rowwise, + S_dec_b_fp8_map4, S_dec_b_fp8_map6); + + const scaling_coeff_type SFcoefficient_map4 = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map4, + block_S_enc_rowwise); + const scaling_coeff_type SFcoefficient_map6 = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map6, + block_S_enc_rowwise); + + float err_map4 = 0.0f; + float err_map6 = 0.0f; + __align__(8) uint32_t rOut_map4[WAVES]; + __align__(8) uint32_t rOut_map6[WAVES]; + + auto process_wave = [&](const int w) { + const float x[8] = { + static_cast(rIn[w][0].x), static_cast(rIn[w][0].y), + static_cast(rIn[w][1].x), static_cast(rIn[w][1].y), + static_cast(rIn[w][2].x), static_cast(rIn[w][2].y), + static_cast(rIn[w][3].x), static_cast(rIn[w][3].y), + }; + rOut_map4[w] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, + &err_map4); + rOut_map6[w] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, + &err_map6); + }; + + if (bank_group == 0) { + process_wave(0); + process_wave(1); + } else { + process_wave(1); + process_wave(0); + } + + if (err_map4 < err_map6) { + S_dec_b_fp8 = S_dec_b_fp8_map4; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + rOut[w] = rOut_map4[w]; + } + } else { + S_dec_b_fp8 = S_dec_b_fp8_map6; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + rOut[w] = rOut_map6[w]; + } + } } else { - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, block_S_enc_rowwise); + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, block_S_enc_rowwise); + +// Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); + + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[w] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[w] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + } } // Store scaling factors to SMEM buffer (R2S) @@ -339,29 +465,15 @@ __device__ __forceinline__ void rowwise_scaling( // Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { - const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); - const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); - - uint32_t out_x8; - if constexpr (USE_STOCHASTIC_ROUNDING) { - const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); - const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); - out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( - elts03, elts47, SFcoefficient, rbits03, rbits47); - } else { - out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, - SFcoefficient); - } - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; - ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], rOut[w]); } } } template + bool ROW_SCALED_NVFP4, bool USE_4OVER6> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -429,12 +541,15 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f - : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); const float S_enc_colwise = (amax_colwise_ptr == nullptr) ? S_enc_rowwise - : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float global_amax_colwise = (amax_colwise_ptr == nullptr) + ? ((amax_rowwise_ptr == nullptr) ? 1.0f : *amax_rowwise_ptr) + : *amax_colwise_ptr; __shared__ uint64_t workID_mbar; __shared__ __uint128_t workID_response; @@ -582,14 +697,14 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( + rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { - colwise_scaling( - sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, stage_Y, stage_X, buff_in, - buff_out_tr, rng, random_uint4, rnd_idx); + colwise_scaling( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, global_amax_colwise, stage_Y, + stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); } // Wait for shared memory writes to be visible to TMA engine @@ -691,6 +806,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; // If transposed output is allocated, return the transposed data @@ -710,6 +826,8 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, "Row-scaled NVFP4 quantization requires rowwise amax."); NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), "Row-scaled NVFP4 quantization does not produce columnwise output."); + NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); if (return_transpose) { NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), @@ -801,18 +919,22 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, use_fast_math, USE_FAST_MATH, TRANSFORMER_ENGINE_SWITCH_CONDITION( row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = - quantize_transpose_nvfp4_tuned_1D_kernel; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - });););); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6, USE_4OVER6, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, + scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, + amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, + rng_state); + }););););); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 28218e2b43..1b3aa1950f 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -230,6 +230,10 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz chunk.set_row_scaled_nvfp4(source.get_row_scaled_nvfp4()); continue; } + if (param_type == NVTETensorParam::kNVTENVFP44Over6) { + chunk.set_nvfp4_4over6(source.get_nvfp4_4over6()); + continue; + } auto param = source.get_parameter(param_type); auto param_dptr = reinterpret_cast(param.data_ptr); auto param_dtype = static_cast(param.dtype); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 12479f2a9c..b9d1c3f70e 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -178,6 +178,11 @@ struct Tensor { * Only meaningful for NVFP4 tensors. */ bool row_scaled_nvfp4 = false; + /*! \brief Whether NVFP4 uses 4over6 block scale selection. + * + * Only meaningful for NVFP4 tensors. + */ + bool nvfp4_4over6 = false; /*! Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -189,7 +194,8 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales - sizeof(uint8_t) // kNVTERowScaledNVFP4 + sizeof(uint8_t), // kNVTERowScaledNVFP4 + sizeof(uint8_t) // kNVTENVFP44Over6 }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -206,6 +212,7 @@ struct Tensor { scaling_mode = NVTE_DELAYED_TENSOR_SCALING; with_gemm_swizzled_scales = false; row_scaled_nvfp4 = false; + nvfp4_4over6 = false; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } @@ -477,6 +484,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool nvfp4_4over6 = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -486,7 +494,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // nvfp4_4over6 }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 045ae88893..35bf020e8b 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -83,6 +83,7 @@ enum NVTETensorParam { * its values are populated during quantization. */ kNVTERowScaledNVFP4 = 8, + kNVTENVFP44Over6 = 9, /*!< Whether an NVFP4 tensor uses 4over6 scaling */ kNVTENumTensorParams }; @@ -381,6 +382,8 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! Whether to use NVFP4 4over6 block scale selection */ + kNVTEQuantizationConfigNVFP44Over6 = 8, kNVTEQuantizationConfigNumAttributes }; @@ -781,6 +784,11 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val)); } + void set_nvfp4_4over6(bool nvfp4_4over6) { + const auto val = static_cast(nvfp4_4over6); + nvte_set_tensor_param_v2(tensor_, kNVTENVFP44Over6, &val, sizeof(val)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -823,6 +831,12 @@ class TensorWrapper { return static_cast(val); } + bool get_nvfp4_4over6() const { + uint8_t val = 0; + nvte_get_tensor_param_v2(tensor_, kNVTENVFP44Over6, &val, sizeof(val), nullptr); + return static_cast(val); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -1318,6 +1332,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether to use NVFP4 4over6 block scale selection */ + void set_nvfp4_4over6(bool nvfp4_4over6) { + const auto val = static_cast(nvfp4_4over6); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6, &val, + sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b773a81d1b..61eb15dcca 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -522,6 +522,9 @@ class NVFP4BlockScaling(Recipe): If set to `True`, forward activation quantizers emit row-scaled NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored as a vector with one FP32 value per tensor row. + enable_4over6 : bool, default = False + If set to `True`, NVFP4 1D quantization evaluates per-block 4over6 + and 6-over-6 candidates and chooses the one with lower MSE. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -536,6 +539,7 @@ class NVFP4BlockScaling(Recipe): ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" + enable_4over6: bool = os.getenv("NVTE_NVFP4_ENABLE_4OVER6", "0") == "1" fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -551,6 +555,14 @@ def __post_init__(self) -> None: assert ( self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." + if self.enable_4over6: + assert self.disable_rht, "NVFP4 4over6 currently requires RHT to be disabled" + assert ( + self.disable_stochastic_rounding + ), "NVFP4 4over6 currently requires stochastic rounding to be disabled" + assert ( + self.disable_2d_quantization + ), "NVFP4 4over6 currently requires 2D quantization to be disabled" # Quantization params # Note: RHT is currently only applied to column-wise usage so that @@ -580,6 +592,7 @@ def _make_repr(self) -> str: f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " f"row_scaled_activation={self.row_scaled_activation}, " + f"enable_4over6={self.enable_4over6}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a52d76019..2f5cf8ac9c 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -855,6 +855,9 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTERowScaledNVFP4: t.row_scaled_nvfp4 = static_cast(*reinterpret_cast(buf)); break; + case kNVTENVFP44Over6: + t.nvfp4_4over6 = static_cast(*reinterpret_cast(buf)); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -938,6 +941,9 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTERowScaledNVFP4: *reinterpret_cast(buf) = static_cast(t->row_scaled_nvfp4); break; + case kNVTENVFP44Over6: + *reinterpret_cast(buf) = static_cast(t->nvfp4_4over6); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -1049,6 +1055,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP44Over6: + bool_to_uint8(config_.nvfp4_4over6, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1104,6 +1113,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP44Over6: + uint8_to_bool(buf, config_.nvfp4_4over6); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index c462b30147..e4b839ea6d 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -66,9 +66,9 @@ void quantize_transpose_vector_blockwise_fp4( const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, - const bool swizzled_scale, const bool use_stochastic_rounding, + const bool swizzled_scale, const bool use_stochastic_rounding, const bool use_fast_math, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const SimpleTensor &noop_tensor, cudaStream_t stream); + const bool use_4over6, const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index cf9821f1a9..7100326305 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -14,6 +14,7 @@ #include #include +#include "common/cast/nvfp4/core_nvfp4.cuh" #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" @@ -187,8 +188,9 @@ __device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scal return static_cast(input) * encode_scale; } +template __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { - constexpr float fp8_max = TypeExtrema::max; + constexpr float fp8_max = kUse4Over6 ? 256.0f : TypeExtrema::max; constexpr float fp4_max = TypeExtrema::max; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 @@ -316,7 +318,8 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kRowScaledNVFP4, + bool kUse4Over6, bool kUseFastMath> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -418,7 +421,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; const float global_encode_scale = - kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); + kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; const float global_decode_scale = 1.0 / global_encode_scale; @@ -513,15 +516,74 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float row_global_encode_scale = global_encode_scale; if constexpr (kRowScaledNVFP4) { row_global_encode_scale = - row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) : 1.0f; + row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) + : 1.0f; } const float row_global_encode_scale_multiplier = kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; const float row_global_decode_scale = kRowScaledNVFP4 ? 1.0f / row_global_encode_scale : global_decode_scale; - ScaleType scale_inv = - ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); - float encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); + ScaleType scale_inv; + float encode_scale; + OVec output_vec; + if constexpr (kUse4Over6) { + ScaleType scale_inv_map4; + ScaleType scale_inv_map6; + transformer_engine::dispatch::nvfp4::core::compute_4over6_decoding_scaling_factors( + amax, row_global_encode_scale, scale_inv_map4, scale_inv_map6); + const float encode_scale_map4 = + ComputeEncodeScaleFP4(scale_inv_map4, row_global_decode_scale); + const float encode_scale_map6 = + ComputeEncodeScaleFP4(scale_inv_map6, row_global_decode_scale); + float row_global_amax; + if constexpr (kRowScaledNVFP4) { + if (row_idx < num_rows) { + row_global_amax = global_amax[row_idx]; + } else { + row_global_amax = 1.0f; + } + } else { + row_global_amax = global_amax[0]; + } + + float err_map4 = 0.0f; + float err_map6 = 0.0f; + uint32_t output_vec_map4[2]; + uint32_t output_vec_map6[2]; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; i += 4) { + const int out_idx = i / 4; + const float x[8] = { + static_cast(smem_vec[i + 0].data.elt[0]), + static_cast(smem_vec[i + 0].data.elt[1]), + static_cast(smem_vec[i + 1].data.elt[0]), + static_cast(smem_vec[i + 1].data.elt[1]), + static_cast(smem_vec[i + 2].data.elt[0]), + static_cast(smem_vec[i + 2].data.elt[1]), + static_cast(smem_vec[i + 3].data.elt[0]), + static_cast(smem_vec[i + 3].data.elt[1]), + }; + output_vec_map4[out_idx] = + transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< + kUseFastMath>(x, encode_scale_map4, scale_inv_map4, row_global_amax, &err_map4); + output_vec_map6[out_idx] = + transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< + kUseFastMath>(x, encode_scale_map6, scale_inv_map6, row_global_amax, &err_map6); + } + + if (err_map4 < err_map6) { + scale_inv = scale_inv_map4; + *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map4[0]; + *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map4[1]; + } else { + scale_inv = scale_inv_map6; + *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map6[0]; + *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map6[1]; + } + } else { + scale_inv = ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); + encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); + } // Step 2.5: Write scale_inv bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -541,22 +603,24 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } // Step 2.6: Quantize - OVec output_vec; + if constexpr (!kUse4Over6) { #pragma unroll - for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { - // Pack two elements into __nv_bfloat162 - float2 f2_a; - float2 f2_b; - f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); - f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); - f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); - f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); - const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __nv_fp4x4_e2m1 - __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - - output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); + const uint32_t rbits = + kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } } // Step 2.7: Store output_c if constexpr (kAligned) { @@ -643,9 +707,57 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo amax = __shfl_sync(mask, amax, src_lane); } // Step 3.4: Compute scale - ScaleType scale_inv = - ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); - float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + ScaleType scale_inv; + float encode_scale; + OVec output_vec; + if constexpr (kUse4Over6) { + ScaleType scale_inv_map4; + ScaleType scale_inv_map6; + transformer_engine::dispatch::nvfp4::core::compute_4over6_decoding_scaling_factors( + amax, global_encode_scale, scale_inv_map4, scale_inv_map6); + const float encode_scale_map4 = + ComputeEncodeScaleFP4(scale_inv_map4, global_decode_scale); + const float encode_scale_map6 = + ComputeEncodeScaleFP4(scale_inv_map6, global_decode_scale); + + float err_map4 = 0.0f; + float err_map6 = 0.0f; + uint32_t output_vec_map4[2]; + uint32_t output_vec_map6[2]; +#pragma unroll + for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 4) { + const int out_idx = i / 4; + const float x[8] = { + static_cast(smem_vec[2 * (i + 0)].data.elt[smem_idx]), + static_cast(smem_vec[2 * (i + 0) + 1].data.elt[smem_idx]), + static_cast(smem_vec[2 * (i + 1)].data.elt[smem_idx]), + static_cast(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx]), + static_cast(smem_vec[2 * (i + 2)].data.elt[smem_idx]), + static_cast(smem_vec[2 * (i + 2) + 1].data.elt[smem_idx]), + static_cast(smem_vec[2 * (i + 3)].data.elt[smem_idx]), + static_cast(smem_vec[2 * (i + 3) + 1].data.elt[smem_idx]), + }; + output_vec_map4[out_idx] = + transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< + kUseFastMath>(x, encode_scale_map4, scale_inv_map4, global_amax[0], &err_map4); + output_vec_map6[out_idx] = + transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< + kUseFastMath>(x, encode_scale_map6, scale_inv_map6, global_amax[0], &err_map6); + } + + if (err_map4 < err_map6) { + scale_inv = scale_inv_map4; + *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map4[0]; + *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map4[1]; + } else { + scale_inv = scale_inv_map6; + *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map6[0]; + *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map6[1]; + } + } else { + scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); + encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + } // Step 3.5: Write scale_inv_t bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -665,27 +777,29 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } // Step 3.6: Quantize - OVec output_vec; + if constexpr (!kUse4Over6) { #pragma unroll - for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { - // Pack two elements into __nv_bfloat162 - float2 f2_a; - float2 f2_b; - f2_a.x = - ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], encode_scale); - f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], - encode_scale); - f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], - encode_scale); - f2_b.y = ComputeOutputFP4(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], - encode_scale); - const uint32_t rbits = - kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __nv_fp4x4_e2m1 - __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - - output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], + encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], + encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], + encode_scale); + f2_b.y = ComputeOutputFP4( + smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], encode_scale); + const uint32_t rbits = + kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = + cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } } // Step 3.7: Store output_t if constexpr (kAligned) { @@ -718,9 +832,9 @@ void quantize_transpose_vector_blockwise_fp4( const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, - const bool swizzled_scale, const bool use_stochastic_rounding, + const bool swizzled_scale, const bool use_stochastic_rounding, const bool use_fast_math, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const SimpleTensor& noop_tensor, cudaStream_t stream) { + const bool use_4over6, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -737,6 +851,10 @@ void quantize_transpose_vector_blockwise_fp4( "Row-scaled NVFP4 quantization only supports rowwise quantization."); NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!use_4over6 || !use_2d_quantization, + "NVFP4 4over6 quantization does not support 2D quantization."); + NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -819,38 +937,50 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( row_scaled_nvfp4, kRowScaledNVFP4, - size_t smem_bytes = kSMemSize * sizeof(InputType); - auto kernel = block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, - float, InputType, OutputType, ScaleType, kSwizzledScale, - kApplyStochasticRounding, kIs2DBlockScaling, - kRowScaledNVFP4>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK(err == cudaSuccess, - "Failed to set dynamic shared memory size."); - } kernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), - row_length, num_rows, scale_stride_x, scale_stride_y, - scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, - epsilon, rng_state, - noop_ptr);) // kRowScaledNVFP4 - ) // kIs2DBlockScaling - ) // kApplyStochasticRounding - ) // kSwizzledScale - ) // kAligned - ) // kReturnTranspose - ) // kReturnIdentity - ) // OutputType - ) // InputType + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6, kUse4Over6, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, kUseFastMath, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, + kAligned, float, InputType, OutputType, ScaleType, + kSwizzledScale, kApplyStochasticRounding, + kIs2DBlockScaling, kRowScaledNVFP4, kUse4Over6, + kUseFastMath>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK( + err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, + scale_stride_y, scale_t_stride_x, + scale_t_stride_y, kScaleBlockDim, epsilon, + rng_state, + noop_ptr);) // kUseFastMath + ) // kUse4Over6 + ) // kRowScaledNVFP4 + ) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); #else diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94350da1e6..6f14fcdf89 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -327,6 +327,7 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + bool use_4over6; // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. bool row_scaled_nvfp4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b38339d67..d4a27c61b4 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -84,6 +84,8 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization, "2D scaling grouped quant kernel is not ready yet"); + NVTE_CHECK(!nvfp4_quantizer_cpp->use_4over6, + "NVFP4 4over6 quantization is not supported for grouped quantization."); auto quant_config_cpp = QuantizationConfigWrapper(); @@ -722,6 +724,7 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; + const bool use_4over6 = quantizer_cpp_list[0]->use_4over6; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); @@ -866,10 +869,10 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, - columnwise_scale, amax_rowwise, amax_columnwise, - fp4_dtype, quantizer_py_list[i], - with_gemm_swizzled_scales, row_scaled_nvfp4)); + tensor_py_list.emplace_back( + NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, + amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, row_scaled_nvfp4, use_4over6)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -887,6 +890,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); + tensor_wrapper.set_nvfp4_4over6(use_4over6); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -997,6 +1001,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = split_sections.size(); const auto &quantizer = *quantizers.front(); + NVTE_CHECK(!quantizer.use_4over6, + "NVFP4 4over6 quantization is not supported with RHT split quantization."); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1157,6 +1163,8 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = input_list.size(); const auto &quantizer = *quantizers.front(); + NVTE_CHECK(!quantizer.use_4over6 || !quantizer.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1189,6 +1197,17 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, need_separate_rng_states, quant_config_list, dummy_quant_config_list_colwise); // colwise rng states are not needed in this case + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6(quantizer.use_4over6); + } + + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + for (auto &config : quant_config_list) { + config.set_use_fast_math(true); + } + } + // We need: // 1. Rowwise amax = amax for input // 2. Columnwise amax = amax for input too @@ -1259,6 +1278,11 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, "NVFP4 split-quantize does not support 2D quantization"); NVTE_CHECK(!quantizer.with_amax_reduction, "NVFP4 split-quantize does not support amax reduction"); + if (quantizer.use_4over6) { + NVTE_CHECK(!quantizer.with_rht, "NVFP4 4over6 quantization does not support RHT."); + NVTE_CHECK(!quantizer.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + } // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7045995dd7..b7abf0da7c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1729,6 +1729,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + this->use_4over6 = quantizer.attr("use_4over6").cast(); this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); // Get amax reduction group if needed for NVFP4 AG @@ -1778,6 +1779,7 @@ std::pair NVFP4Quantizer::create_tensor( "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool use_4over6 = this->use_4over6; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -1845,6 +1847,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["use_4over6"] = py::cast(use_4over6); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1875,6 +1878,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["device"] = py::cast(device); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["use_4over6"] = py::cast(use_4over6); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1908,6 +1912,7 @@ std::pair NVFP4Quantizer::create_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); + out_cpp.set_nvfp4_4over6(use_4over6); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1936,6 +1941,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool use_4over6 = this->use_4over6; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -2010,6 +2016,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["use_4over6"] = py::cast(use_4over6); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -2091,6 +2098,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( "Row-scaled NVFP4 quantization does not support columnwise usage."); } tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); + tensor.attr("_use_4over6") = py::cast(use_4over6); // Coerce row-wise data if (rowwise_usage) { @@ -2195,6 +2203,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); + out_cpp.set_nvfp4_4over6(use_4over6); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2285,6 +2294,15 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); + quant_config.set_nvfp4_4over6(this->use_4over6); + + if (this->use_4over6) { + NVTE_CHECK(!this->with_rht, "NVFP4 4over6 quantization does not support RHT."); + NVTE_CHECK(!this->with_2d_quantization, + "NVFP4 4over6 quantization does not support 2D quantization."); + NVTE_CHECK(!this->stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + } // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 37ab0b0535..69ed98b435 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -135,6 +135,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); + const bool use_4over6 = tensor.attr("_use_4over6").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -165,6 +166,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); ret.set_row_scaled_nvfp4(row_scaled_nvfp4); + ret.set_nvfp4_4over6(use_4over6); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index acb7abefd1..f0a07b4e4e 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -350,6 +350,7 @@ def __init__( eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), row_scaled_nvfp4: bool = False, + use_4over6: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): @@ -360,6 +361,13 @@ def __init__( raise ValueError( "Row-scaled NVFP4 reference quantization does not support columnwise usage." ) + if use_4over6: + if pow_2_scales: + raise ValueError("4over6 is only supported for NVFP4 (non-pow2) mode.") + if quant_tile_shape != (1, 16): + raise ValueError("4over6 reference quantization only supports 1x16 tiles.") + if with_rht: + raise ValueError("4over6 reference quantization does not support RHT.") super().__init__(rowwise=rowwise, columnwise=columnwise) self.internal = True @@ -368,6 +376,7 @@ def __init__( self.eps = eps self.quant_tile_shape = quant_tile_shape self.row_scaled_nvfp4 = row_scaled_nvfp4 + self.use_4over6 = use_4over6 self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -456,6 +465,7 @@ def _quantize_blockwise_reference( *, pow_2_scales: bool, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -488,6 +498,9 @@ def _quantize_blockwise_reference( x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + GLOBAL_SCALE_E4M3_MAX = torch.tensor( + 256.0 if use_4over6 else 448.0, device=x.device, dtype=torch.float32 + ) decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) if pow_2_scales: @@ -497,10 +510,12 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: + if use_4over6 and using_2d_quantization: + raise ValueError("4over6 reference quantization does not support 2D quantization.") if row_scaled_nvfp4: global_amax = global_amax.to(torch.float32).view(m, 1, 1) - global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.div(GLOBAL_SCALE_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( global_encode_scale, torch.tensor( @@ -519,30 +534,111 @@ def _quantize_blockwise_reference( global_encode_scale, ) global_decode_scale = torch.div(1.0, global_encode_scale) - global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - - # Match the kernel's default path: fold the FP4 reciprocal into the - # global scale multiplier, but keep the final reciprocal exact. - decode_scale = vec_max * global_encode_scale_multiplier - decode_scale = torch.min( - decode_scale, - torch.tensor( + if use_4over6: + # FourOverSix compares map-to-4 and map-to-6 candidates using + # the original input-domain MSE, while keeping TE-style FP4 + # quantization for each candidate. + decode_scale_map6 = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale + decode_scale_map4 = decode_scale_map6 * 1.5 + decode_scale_map4 = torch.clamp( + decode_scale_map4, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX + ).to(torch.float8_e4m3fn) + decode_scale_map6 = torch.clamp( + decode_scale_map6, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX + ).to(torch.float8_e4m3fn) + + fp32_max = torch.tensor( torch.finfo(torch.float32).max, - device=decode_scale.device, + device=decode_scale_map4.device, dtype=torch.float32, - ), - ) - decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - decode_scale = decode_scale.to(torch.float8_e4m3fn) + ) + encode_scale_map4 = torch.min( + torch.div(1.0, decode_scale_map4.to(torch.float32) * global_decode_scale), + fp32_max, + ) + encode_scale_map6 = torch.min( + torch.div(1.0, decode_scale_map6.to(torch.float32) * global_decode_scale), + fp32_max, + ) - encode_scale = torch.min( - torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), - torch.tensor( - torch.finfo(torch.float32).max, - device=decode_scale.device, - dtype=torch.float32, - ), - ) + clipped_x_map4 = torch.clamp( + x.to(torch.float32) * encode_scale_map4, + -FLOAT4_E2M1_MAX, + FLOAT4_E2M1_MAX, + ).reshape(m, n) + clipped_x_map6 = torch.clamp( + x.to(torch.float32) * encode_scale_map6, + -FLOAT4_E2M1_MAX, + FLOAT4_E2M1_MAX, + ).reshape(m, n) + qx_map4 = cast_to_fp4x2(clipped_x_map4) + qx_map6 = cast_to_fp4x2(clipped_x_map6) + + fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view( + m, n // tile_len_x, tile_len_x + ) + fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view( + m, n // tile_len_x, tile_len_x + ) + denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX + sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) + sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) + if row_scaled_nvfp4: + mse_global_amax = global_amax.squeeze(-1) + else: + mse_global_amax = global_amax + x_float = x.to(torch.float32) + err_map4 = torch.zeros_like(vec_max) + err_map6 = torch.zeros_like(vec_max) + for idx in range(tile_len_x): + val_map4 = fp4_map4[:, :, idx] * sf_map4 + val_map4 = val_map4 * mse_global_amax + val_map4 = val_map4 / denom + diff_map4 = val_map4 - x_float[:, :, idx] + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + + val_map6 = fp4_map6[:, :, idx] * sf_map6 + val_map6 = val_map6 * mse_global_amax + val_map6 = val_map6 / denom + diff_map6 = val_map6 - x_float[:, :, idx] + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + pick_map4 = err_map4 < err_map6 + qx = torch.where( + pick_map4.expand(-1, -1, tile_len_x // 2), + qx_map4.view(m, n // tile_len_x, tile_len_x // 2), + qx_map6.view(m, n // tile_len_x, tile_len_x // 2), + ).reshape(m, n // 2) + decode_scale = torch.where(pick_map4, decode_scale_map4, decode_scale_map6).squeeze( + -1 + ) + return qx, decode_scale + else: + global_encode_scale_multiplier = global_encode_scale * torch.reciprocal( + FLOAT4_E2M1_MAX + ) + + # Match the kernel's default path: fold the FP4 reciprocal into the + # global scale multiplier, but keep the final reciprocal exact. + decode_scale = vec_max * global_encode_scale_multiplier + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) scaled_x = x.to(torch.float32) * encode_scale @@ -679,6 +775,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, + use_4over6=self.use_4over6, eps=self.eps, ) if transpose_scales: @@ -702,6 +799,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, + use_4over6=self.use_4over6, eps=self.eps, ) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c40723517..5642344143 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1668,6 +1668,7 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and tensor_type != "weight" and self.recipe.row_scaled_activation ), + use_4over6=self.recipe.enable_4over6, ) if self.mode not in ("forward", "backward"): diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index f28f972b58..b1fe0c432d 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -93,6 +93,7 @@ def __new__( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, ): if ( shapes is not None @@ -166,6 +167,7 @@ def __new__( columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) return instance @@ -197,7 +199,8 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.logical_shape = src.logical_shape dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales - dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 + dst._row_scaled_nvfp4 = src._row_scaled_nvfp4 + dst._use_4over6 = src._use_4over6 def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index df7a2b4bd3..1cb906d61f 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -130,6 +130,8 @@ class NVFP4Quantizer(Quantizer): """Whether emitted NVFP4 tensors store one FP32 amax per row.""" row_scaled_nvfp4: bool + """Whether to use NVFP4 4over6 block scale selection.""" + use_4over6: bool """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -147,6 +149,7 @@ def __init__( with_2d_quantization: bool = False, stochastic_rounding: bool = False, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -158,6 +161,7 @@ def __init__( self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 + self.use_4over6 = use_4over6 self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -204,6 +208,7 @@ def copy(self) -> NVFP4Quantizer: with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, row_scaled_nvfp4=self.row_scaled_nvfp4, + use_4over6=self.use_4over6, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -356,6 +361,7 @@ def __new__( quantizer: Quantizer, with_gemm_swizzled_scales: bool, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, **kwargs, ): instance = super().__new__( @@ -371,6 +377,7 @@ def __new__( with_gemm_swizzled_scales, *args, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, **kwargs, ) return instance diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index ac56d334bc..0b68c2a439 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -48,6 +48,13 @@ class GroupedTensorStorage: Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. """ + # Whether scaling factors are in the swizzled format expected by GEMM + _with_gemm_swizzled_scales: bool + # Whether grouped NVFP4 tensors use row-scaled amax metadata + _row_scaled_nvfp4: bool + # Whether grouped NVFP4 tensors use 4over6 block scale selection + _use_4over6: bool + @staticmethod def _initialize_storage_fields( instance: "GroupedTensorStorage", @@ -73,6 +80,7 @@ def _initialize_storage_fields( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, ) -> None: """ Initialize a GroupedTensor. @@ -148,7 +156,8 @@ def _initialize_storage_fields( # Used as a convenience. instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales - instance.row_scaled_nvfp4 = row_scaled_nvfp4 + instance._row_scaled_nvfp4 = row_scaled_nvfp4 + instance._use_4over6 = use_4over6 def __new__( cls, @@ -175,6 +184,7 @@ def __new__( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -201,6 +211,7 @@ def __new__( stride=stride, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) return instance @@ -307,6 +318,24 @@ def get_dtype(self) -> torch.dtype: return self.fake_dtype + @property + def row_scaled_nvfp4(self) -> bool: + """Whether grouped NVFP4 tensors use row-scaled amax metadata.""" + return self._row_scaled_nvfp4 + + @row_scaled_nvfp4.setter + def row_scaled_nvfp4(self, row_scaled_nvfp4: bool) -> None: + self._row_scaled_nvfp4 = row_scaled_nvfp4 + + @property + def use_4over6(self) -> bool: + """Whether grouped NVFP4 tensors use 4over6 block scale selection.""" + return self._use_4over6 + + @use_4over6.setter + def use_4over6(self, use_4over6: bool) -> None: + self._use_4over6 = use_4over6 + def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], "GroupedTensorStorage"]: @@ -375,7 +404,8 @@ def clear(self) -> None: self.columnwise_scale_inv_offsets = None self.tensor_shapes = [] self.fake_dtype = torch.float32 - self.row_scaled_nvfp4 = False + self._row_scaled_nvfp4 = False + self._use_4over6 = False def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -544,7 +574,8 @@ def copy(self) -> "GroupedTensorStorage": scale_inv_offsets=self.scale_inv_offsets, columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, - row_scaled_nvfp4=self.row_scaled_nvfp4, + row_scaled_nvfp4=self._row_scaled_nvfp4, + use_4over6=self._use_4over6, ) @staticmethod @@ -656,6 +687,7 @@ def make_grouped_tensor( scale_inv_offsets = None columnwise_scale_inv_offsets = None row_scaled_nvfp4 = False + use_4over6 = False if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -715,6 +747,7 @@ def make_grouped_tensor( amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): row_scaled_nvfp4 = quantizer.row_scaled_nvfp4 + use_4over6 = quantizer.use_4over6 if row_scaled_nvfp4: if not rowwise_usage: raise ValueError( @@ -843,6 +876,7 @@ def make_grouped_tensor( quantizer.optimize_for_gemm if quantizer is not None else False ), row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -956,7 +990,8 @@ def split_into_quantized_tensors( columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets nvfp4_rowwise_amax_offsets = None - row_scaled_nvfp4 = self.row_scaled_nvfp4 + row_scaled_nvfp4 = self._row_scaled_nvfp4 + use_4over6 = self._use_4over6 if recipe.nvfp4() and row_scaled_nvfp4: cum = 0 nvfp4_rowwise_amax_offsets = [0] @@ -1184,6 +1219,7 @@ def split_into_quantized_tensors( quantizer=quantizer, with_gemm_swizzled_scales=quantizer.optimize_for_gemm, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index e51acb71e5..94bf9cd3e6 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -99,6 +99,8 @@ class NVFP4TensorStorage(QuantizedTensorStorage): _with_gemm_swizzled_scales: bool # Whether this NVFP4 tensor uses row-scaled amax metadata _row_scaled_nvfp4: bool + # Whether this NVFP4 tensor uses 4over6 block scale selection + _use_4over6: bool def __new__( cls, @@ -114,6 +116,7 @@ def __new__( *args, fake_dtype: Optional[torch.dtype] = None, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, **kwargs, ): if cls is NVFP4TensorStorage: @@ -132,6 +135,7 @@ def __new__( instance._amax_columnwise = amax_columnwise instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance._row_scaled_nvfp4 = row_scaled_nvfp4 + instance._use_4over6 = use_4over6 return instance @@ -158,6 +162,8 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("Scale layout mismatch in copy_from_storage") if self._row_scaled_nvfp4 != src._row_scaled_nvfp4: raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") + if self._use_4over6 != src._use_4over6: + raise RuntimeError("NVFP4 4over6 mode mismatch in copy_from_storage") def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): if dst is not None and src_tensor is not None: @@ -183,6 +189,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, "row_scaled_nvfp4": self._row_scaled_nvfp4, + "use_4over6": self._use_4over6, "fake_dtype": self._dtype, } @@ -316,6 +323,7 @@ def view(self, shape: torch.Size): fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self._row_scaled_nvfp4, + use_4over6=self._use_4over6, fake_dtype=self._dtype, ) From 7b0b2d03f36d502a2e89db92a664447cd79ccb59 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 8 May 2026 21:08:36 -0700 Subject: [PATCH 02/57] Make 4over6 compile time for dequant Signed-off-by: Ziang Li --- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 9bc54fc191..8f12be20bf 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -31,11 +31,10 @@ namespace dispatch { namespace nvfp4 { namespace dequantize_kernel { #if FP4_TYPE_SUPPORTED -template +template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const bool row_scaled_nvfp4, - const bool use_4over6, const size_t N, const size_t M, + const float *const tensor_amax, const size_t N, const size_t M, const size_t scale_stride, const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; @@ -64,8 +63,13 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = row_scaled_nvfp4 ? tensor_amax[y] : tensor_amax[0]; - const float fp8_max = use_4over6 ? 256.0f : 448.0f; + float amax; + if constexpr (ROW_SCALED_NVFP4) { + amax = tensor_amax[y]; + } else { + amax = tensor_amax[0]; + } + constexpr float fp8_max = USE_4OVER6 ? 256.0f : 448.0f; const float factor_inv = 1.0f / (6.0f * fp8_max); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll @@ -114,14 +118,17 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) output->data.dtype, OType, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - - dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), row_scaled_nvfp4, use_4over6, N, Mread, - input.scale_inv.shape.back(), - num_scale_tiles_X);); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, ROW_SCALED_NVFP4, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6, USE_4OVER6, + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X);););); // NOLINT(*) + ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); From 1e5b6ad53d66e5e4466cff145451beb553a2d7d8 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 8 May 2026 21:25:25 -0700 Subject: [PATCH 03/57] Expand 1d fwd+bwd test Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 9 ++++++++- tests/pytorch/test_sanity.py | 14 ++++++++++++++ tests/pytorch/utils.py | 12 +++++++++++- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 832b5a3741..6774b96a40 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -39,7 +39,7 @@ # -------------------------- _BACKWARD_OVERRIDES = ("high_precision", "dequantized") -_NVFP4_RECIPE_NAMES = ("nvfp4", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6") +_NVFP4_RECIPE_NAMES = ("nvfp4", "nvfp4_4over6", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6") _NVFP4_ROW_SCALED_RECIPE_NAMES = ("nvfp4_row_scaled", "nvfp4_row_scaled_4over6") fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -80,6 +80,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_4over6", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP44Over6BlockScaling", + ), pytest.param( "nvfp4_row_scaled_4over6", marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), @@ -187,6 +192,8 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") if module_type == "ops_linear" and recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES: pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") + if module_type == "grouped_linear" and recipe_name == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 67fefc8217..2b4799d99f 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -94,6 +94,19 @@ def nvfp4_vanilla(): return nvfp4_recipe +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + enable_4over6=True, + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def nvfp4_row_scaled_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, @@ -113,6 +126,7 @@ def nvfp4_row_scaled_4over6(): fp8_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this + fp8_recipes.append(nvfp4_4over6()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 62742d4ebb..a60a52be4c 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6"): + if name in ("nvfp4", "nvfp4_4over6", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -152,6 +152,14 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) + if name == "nvfp4_4over6": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + enable_4over6=True, + **recipe_kwargs, + ) if name == "nvfp4_row_scaled": return transformer_engine.common.recipe.NVFP4BlockScaling( disable_rht=True, @@ -178,6 +186,8 @@ def recipe_id(recipe: Optional[Recipe]) -> str: return "None" if recipe.nvfp4() and recipe.row_scaled_activation and recipe.enable_4over6: return "NVFP4RowScaled4Over6BlockScaling" + if recipe.nvfp4() and recipe.enable_4over6: + return "NVFP44Over6BlockScaling" if recipe.nvfp4() and recipe.row_scaled_activation: return "NVFP4RowScaledBlockScaling" return type(recipe).__name__ From 99660fc01175745338d87a9a74aee9e7700ab5fb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 8 May 2026 21:46:31 -0700 Subject: [PATCH 04/57] Refactor Signed-off-by: Ziang Li --- docs/envvars.rst | 2 +- tests/cpp/operator/test_dequantize_nvfp4.cu | 8 +-- .../nvfp4/test_nvfp4_quantize_exact.py | 13 ++++ .../common/cast/dispatch/quantize.cuh | 15 +++-- .../quantize_transpose_nvfp4_tuned_1D.cuh | 65 ++++++++++++++----- transformer_engine/common/recipe/__init__.py | 4 +- ...quantize_transpose_vector_blockwise_fp4.cu | 12 +--- 7 files changed, 78 insertions(+), 41 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 37b8028533..015593cb67 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -291,7 +291,7 @@ Kernel Configuration :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Enable 4over6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled, either with the corresponding recipe fields or with :envvar:`NVTE_NVFP4_DISABLE_RHT`, :envvar:`NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING`, and :envvar:`NVTE_NVFP4_DISABLE_2D_QUANTIZATION`. + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled, either with the corresponding recipe fields or with :envvar:`NVTE_NVFP4_DISABLE_RHT`, :envvar:`NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING`, and :envvar:`NVTE_NVFP4_DISABLE_2D_QUANTIZATION`. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index a0c53b8f86..ea2ef14916 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -119,9 +119,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Quantize if (rows > 0 && cols > 0) { - QuantizationConfigWrapper quant_config; - quant_config.set_nvfp4_4over6(use_4over6); - nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); + nvte_quantize(input.data(), quantized.data(), 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); @@ -181,9 +179,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, } if (rows > 0 && cols > 0) { - QuantizationConfigWrapper quant_config; - quant_config.set_nvfp4_4over6(use_4over6); - nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); + nvte_quantize(input.data(), quantized_compact.data(), 0); cudaDeviceSynchronize(); } diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index d7f9c8994e..48d4717ecc 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -2,6 +2,8 @@ # # See LICENSE for license information. +import os + import pytest import torch import transformer_engine.pytorch as te @@ -16,6 +18,17 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +@pytest.fixture(autouse=True) +def disable_fast_math_for_exact_reference_tests(): + original = os.environ.get("NVTE_USE_FAST_MATH") + os.environ["NVTE_USE_FAST_MATH"] = "0" + yield + if original is None: + os.environ.pop("NVTE_USE_FAST_MATH", None) + else: + os.environ["NVTE_USE_FAST_MATH"] = original + + def maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4: bool, return_transpose: bool, diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9fccffe993..ca0fed9f16 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -101,11 +101,13 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; - const bool use_4over6 = quant_config_cpp.nvfp4_4over6; + const bool use_4over6 = quant_config_cpp.nvfp4_4over6 || output_tensor->nvfp4_4over6; NVTE_CHECK(!use_4over6 || !quant_config_cpp.nvfp4_2d_quantization, "NVFP4 4over6 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); + quant_config_cpp.nvfp4_4over6 = use_4over6; + output_tensor->nvfp4_4over6 = use_4over6; if (row_scaled_nvfp4) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -256,11 +258,13 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); - const bool use_4over6 = quant_config_cpp.nvfp4_4over6; + const bool use_4over6 = quant_config_cpp.nvfp4_4over6 || output_tensor->nvfp4_4over6; NVTE_CHECK(!use_4over6 || !quant_config_cpp.nvfp4_2d_quantization, "NVFP4 4over6 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); + quant_config_cpp.nvfp4_4over6 = use_4over6; + output_tensor->nvfp4_4over6 = use_4over6; NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && @@ -386,10 +390,13 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); + bool use_4over6 = quant_config_cpp.nvfp4_4over6; + for (const auto *output_tensor : output_tensors) { + use_4over6 = use_4over6 || output_tensor->nvfp4_4over6; + } NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "2D quantization is not supported for group quantize."); - NVTE_CHECK(!quant_config_cpp.nvfp4_4over6, - "NVFP4 4over6 quantization is not supported for group quantize."); + NVTE_CHECK(!use_4over6, "NVFP4 4over6 quantization is not supported for group quantize."); // Launch NVFP4 group quantize kernel nvfp4::group_quantize_transpose( diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index e9a567825d..998539836f 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -396,27 +396,58 @@ __device__ __forceinline__ void rowwise_scaling( __align__(8) uint32_t rOut_map4[WAVES]; __align__(8) uint32_t rOut_map6[WAVES]; - auto process_wave = [&](const int w) { - const float x[8] = { - static_cast(rIn[w][0].x), static_cast(rIn[w][0].y), - static_cast(rIn[w][1].x), static_cast(rIn[w][1].y), - static_cast(rIn[w][2].x), static_cast(rIn[w][2].y), - static_cast(rIn[w][3].x), static_cast(rIn[w][3].y), + if (bank_group == 0) { + const float x0[8] = { + static_cast(rIn[0][0].x), static_cast(rIn[0][0].y), + static_cast(rIn[0][1].x), static_cast(rIn[0][1].y), + static_cast(rIn[0][2].x), static_cast(rIn[0][2].y), + static_cast(rIn[0][3].x), static_cast(rIn[0][3].y), }; - rOut_map4[w] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, + rOut_map4[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x0, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, &err_map4); - rOut_map6[w] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, + rOut_map6[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x0, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, &err_map6); - }; - if (bank_group == 0) { - process_wave(0); - process_wave(1); + const float x1[8] = { + static_cast(rIn[1][0].x), static_cast(rIn[1][0].y), + static_cast(rIn[1][1].x), static_cast(rIn[1][1].y), + static_cast(rIn[1][2].x), static_cast(rIn[1][2].y), + static_cast(rIn[1][3].x), static_cast(rIn[1][3].y), + }; + rOut_map4[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x1, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, + &err_map4); + rOut_map6[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x1, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, + &err_map6); } else { - process_wave(1); - process_wave(0); + const float x1[8] = { + static_cast(rIn[1][0].x), static_cast(rIn[1][0].y), + static_cast(rIn[1][1].x), static_cast(rIn[1][1].y), + static_cast(rIn[1][2].x), static_cast(rIn[1][2].y), + static_cast(rIn[1][3].x), static_cast(rIn[1][3].y), + }; + rOut_map4[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x1, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, + &err_map4); + rOut_map6[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x1, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, + &err_map6); + + const float x0[8] = { + static_cast(rIn[0][0].x), static_cast(rIn[0][0].y), + static_cast(rIn[0][1].x), static_cast(rIn[0][1].y), + static_cast(rIn[0][2].x), static_cast(rIn[0][2].y), + static_cast(rIn[0][3].x), static_cast(rIn[0][3].y), + }; + rOut_map4[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x0, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, + &err_map4); + rOut_map6[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( + x0, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, + &err_map6); } if (err_map4 < err_map6) { @@ -437,7 +468,6 @@ __device__ __forceinline__ void rowwise_scaling( const scaling_coeff_type SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, block_S_enc_rowwise); -// Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); @@ -462,7 +492,6 @@ __device__ __forceinline__ void rowwise_scaling( sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; } -// Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 61eb15dcca..aedd3458fe 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -523,8 +523,8 @@ class NVFP4BlockScaling(Recipe): NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored as a vector with one FP32 value per tensor row. enable_4over6 : bool, default = False - If set to `True`, NVFP4 1D quantization evaluates per-block 4over6 - and 6-over-6 candidates and chooses the one with lower MSE. + If set to `True`, NVFP4 1D quantization evaluates per-block + map-to-4 and map-to-6 candidates and chooses the one with lower MSE. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 7100326305..c3bfb0ab7c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -190,16 +190,8 @@ __device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scal template __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { - constexpr float fp8_max = kUse4Over6 ? 256.0f : TypeExtrema::max; - constexpr float fp4_max = TypeExtrema::max; - float global_encode_scale = fp8_max * fp4_max / global_amax; - // If scale is infinity, return max value of float32 - global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); - // If global amax is 0 or infinity, return 1 - if (global_amax == 0.f || global_encode_scale == 0.f) { - return 1.f; - } - return global_encode_scale; + return transformer_engine::dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4< + kUse4Over6>(global_amax); } __device__ __forceinline__ uint32_t get_rbits( From cb2e0a3b87d2380549638fab2026ea7939f64cf6 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 8 May 2026 21:56:10 -0700 Subject: [PATCH 05/57] Clean up Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 10 +++- tests/pytorch/test_recipe.py | 4 -- tests/pytorch/utils.py | 54 ++++++++----------- .../pytorch/tensor/grouped_tensor.py | 4 +- .../tensor/storage/grouped_tensor_storage.py | 23 +++----- 5 files changed, 39 insertions(+), 56 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 6774b96a40..2d98293dbb 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -39,8 +39,14 @@ # -------------------------- _BACKWARD_OVERRIDES = ("high_precision", "dequantized") -_NVFP4_RECIPE_NAMES = ("nvfp4", "nvfp4_4over6", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6") +_NVFP4_RECIPE_NAMES = ( + "nvfp4", + "nvfp4_4over6", + "nvfp4_row_scaled", + "nvfp4_row_scaled_4over6", +) _NVFP4_ROW_SCALED_RECIPE_NAMES = ("nvfp4_row_scaled", "nvfp4_row_scaled_4over6") +_NVFP4_4OVER6_RECIPE_NAMES = ("nvfp4_4over6", "nvfp4_row_scaled_4over6") fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) @@ -192,7 +198,7 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") if module_type == "ops_linear" and recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES: pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") - if module_type == "grouped_linear" and recipe_name == "nvfp4_4over6": + if module_type == "grouped_linear" and recipe_name in _NVFP4_4OVER6_RECIPE_NAMES: pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 90e7381d44..aa2617a625 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -586,15 +586,11 @@ def test_nvfp4_grouped_storage_metadata(): device=torch.device("cuda"), dtype=torch.bfloat16, ) - assert grouped_tensor._row_scaled_nvfp4 assert grouped_tensor.row_scaled_nvfp4 - assert grouped_tensor._use_4over6 assert grouped_tensor.use_4over6 grouped_copy = grouped_tensor.copy() - assert grouped_copy._row_scaled_nvfp4 assert grouped_copy.row_scaled_nvfp4 - assert grouped_copy._use_4over6 assert grouped_copy.use_4over6 for tensor in grouped_tensor.split_into_quantized_tensors(): diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index a60a52be4c..9022b2b530 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -31,6 +31,16 @@ from transformer_engine.pytorch.module.base import get_dummy_wgrad +_NVFP4_RECIPE_NAMES = ( + "nvfp4", + "nvfp4_4over6", + "nvfp4_row_scaled", + "nvfp4_row_scaled_4over6", +) +_NVFP4_ROW_SCALED_RECIPE_NAMES = ("nvfp4_row_scaled", "nvfp4_row_scaled_4over6") +_NVFP4_4OVER6_RECIPE_NAMES = ("nvfp4_4over6", "nvfp4_row_scaled_4over6") + + def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: """Convert type name to PyTorch dtype""" if isinstance(dtype, torch.dtype): @@ -118,7 +128,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_4over6", "nvfp4_row_scaled", "nvfp4_row_scaled_4over6"): + if name in _NVFP4_RECIPE_NAMES: return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -145,38 +155,16 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) - if name == "nvfp4": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - **recipe_kwargs, - ) - if name == "nvfp4_4over6": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - enable_4over6=True, - **recipe_kwargs, - ) - if name == "nvfp4_row_scaled": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - row_scaled_activation=True, - **recipe_kwargs, - ) - if name == "nvfp4_row_scaled_4over6": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - row_scaled_activation=True, - enable_4over6=True, - **recipe_kwargs, - ) + if name in _NVFP4_RECIPE_NAMES: + kwargs = { + "disable_rht": True, + "disable_stochastic_rounding": True, + "disable_2d_quantization": True, + "row_scaled_activation": name in _NVFP4_ROW_SCALED_RECIPE_NAMES, + "enable_4over6": name in _NVFP4_4OVER6_RECIPE_NAMES, + } + kwargs.update(recipe_kwargs) + return transformer_engine.common.recipe.NVFP4BlockScaling(**kwargs) raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index b1fe0c432d..834194a0ce 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -199,8 +199,8 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.logical_shape = src.logical_shape dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales - dst._row_scaled_nvfp4 = src._row_scaled_nvfp4 - dst._use_4over6 = src._use_4over6 + dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 + dst.use_4over6 = src.use_4over6 def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 0b68c2a439..97272ab954 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -48,13 +48,6 @@ class GroupedTensorStorage: Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. """ - # Whether scaling factors are in the swizzled format expected by GEMM - _with_gemm_swizzled_scales: bool - # Whether grouped NVFP4 tensors use row-scaled amax metadata - _row_scaled_nvfp4: bool - # Whether grouped NVFP4 tensors use 4over6 block scale selection - _use_4over6: bool - @staticmethod def _initialize_storage_fields( instance: "GroupedTensorStorage", @@ -156,8 +149,8 @@ def _initialize_storage_fields( # Used as a convenience. instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales - instance._row_scaled_nvfp4 = row_scaled_nvfp4 - instance._use_4over6 = use_4over6 + instance.row_scaled_nvfp4 = row_scaled_nvfp4 + instance.use_4over6 = use_4over6 def __new__( cls, @@ -404,8 +397,8 @@ def clear(self) -> None: self.columnwise_scale_inv_offsets = None self.tensor_shapes = [] self.fake_dtype = torch.float32 - self._row_scaled_nvfp4 = False - self._use_4over6 = False + self.row_scaled_nvfp4 = False + self.use_4over6 = False def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -574,8 +567,8 @@ def copy(self) -> "GroupedTensorStorage": scale_inv_offsets=self.scale_inv_offsets, columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, - row_scaled_nvfp4=self._row_scaled_nvfp4, - use_4over6=self._use_4over6, + row_scaled_nvfp4=self.row_scaled_nvfp4, + use_4over6=self.use_4over6, ) @staticmethod @@ -990,8 +983,8 @@ def split_into_quantized_tensors( columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets nvfp4_rowwise_amax_offsets = None - row_scaled_nvfp4 = self._row_scaled_nvfp4 - use_4over6 = self._use_4over6 + row_scaled_nvfp4 = self.row_scaled_nvfp4 + use_4over6 = self.use_4over6 if recipe.nvfp4() and row_scaled_nvfp4: cum = 0 nvfp4_rowwise_amax_offsets = [0] From 2c066f9acfe3a1bb93b5877e553bebe1b75e543b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 8 May 2026 22:55:14 -0700 Subject: [PATCH 06/57] Clean up Signed-off-by: Ziang Li --- tests/cpp/operator/test_dequantize_nvfp4.cu | 2 +- tests/pytorch/test_backward_override.py | 31 +++++---- tests/pytorch/test_recipe.py | 70 ++------------------- tests/pytorch/test_sanity.py | 12 +--- tests/pytorch/utils.py | 18 ++---- 5 files changed, 27 insertions(+), 106 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index ea2ef14916..aebd05da5d 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -48,7 +48,7 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, size_t cols, size_t scale_stride, bool use_4over6) { - const float factor_inv = 1.0f / (6.0f * (use_4over6 ? 256.0f : 448.0f)); + constexpr float factor_inv = 1.0f / (6.0f * (use_4over6 ? 256.0f : 448.0f)); constexpr size_t BLOCK_SIZE = 16; const size_t Mread = cols / BLOCK_SIZE; const size_t bytes_per_block = BLOCK_SIZE / 2; diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 2d98293dbb..5a4327b18d 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -39,14 +39,6 @@ # -------------------------- _BACKWARD_OVERRIDES = ("high_precision", "dequantized") -_NVFP4_RECIPE_NAMES = ( - "nvfp4", - "nvfp4_4over6", - "nvfp4_row_scaled", - "nvfp4_row_scaled_4over6", -) -_NVFP4_ROW_SCALED_RECIPE_NAMES = ("nvfp4_row_scaled", "nvfp4_row_scaled_4over6") -_NVFP4_4OVER6_RECIPE_NAMES = ("nvfp4_4over6", "nvfp4_row_scaled_4over6") fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) @@ -86,6 +78,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_row_scaled", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4RowScaledBlockScaling", + ), pytest.param( "nvfp4_4over6", marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), @@ -183,7 +180,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name in _NVFP4_RECIPE_NAMES: + if "nvfp4" in recipe_name: if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -196,14 +193,14 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - if module_type == "ops_linear" and recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES: + if module_type == "ops_linear" and "nvfp4_row_scaled" in recipe_name: pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") - if module_type == "grouped_linear" and recipe_name in _NVFP4_4OVER6_RECIPE_NAMES: + if module_type == "grouped_linear" and "nvfp4_4over6" in recipe_name: pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: - if recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES: + if "nvfp4_row_scaled" in recipe_name: return make_recipe(recipe_name, backward_override="dequantized") return make_recipe(recipe_name) @@ -223,7 +220,7 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name in _NVFP4_RECIPE_NAMES and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if "nvfp4" in recipe_name and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -248,7 +245,7 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name in _NVFP4_RECIPE_NAMES and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if "nvfp4" in recipe_name and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -267,9 +264,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name in _NVFP4_RECIPE_NAMES and any(m % 16 != 0 for m in non_empty_splits): + if "nvfp4" in recipe_name and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name in _NVFP4_RECIPE_NAMES and any(m % 64 != 0 for m in non_empty_splits): + if "nvfp4" in recipe_name and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." @@ -1752,7 +1749,7 @@ def test_backward_override_memory_peak_report( modes = ( ("high_precision", "dequantized") - if recipe_name in _NVFP4_ROW_SCALED_RECIPE_NAMES + if "nvfp4_row_scaled" in recipe_name else (None, "high_precision", "dequantized") ) mode_results: dict[str, dict[str, float] | str] = {} diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index aa2617a625..57305437e7 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -28,7 +28,6 @@ NVFP4BlockScalingRecipeState, _amax_and_scale_update, ) -from transformer_engine.pytorch.tensor.storage.grouped_tensor_storage import GroupedTensorStorage import transformer_engine.pytorch.ops as te_ops from transformer_engine.common.recipe import ( DelayedScaling, @@ -515,33 +514,10 @@ def test_quantizer_update(self, module_class): @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) -def test_nvfp4_row_scaled_quantizer_roles(): - recipe = NVFP4BlockScaling(row_scaled_activation=True) - - forward_quantizers = NVFP4BlockScalingRecipeState( - recipe, - mode="forward", - num_quantizers=3, - ).make_quantizers() - assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] - assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) - assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) - - backward_quantizers = NVFP4BlockScalingRecipeState( - recipe, - mode="backward", - num_quantizers=2, - ).make_quantizers() - assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] - - -@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) -def test_nvfp4_4over6_quantizer_roles(): +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +def test_nvfp4_row_scaled_quantizer_roles(use_4over6): recipe = NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - enable_4over6=True, + enable_4over6=use_4over6, row_scaled_activation=True, ) @@ -550,54 +526,18 @@ def test_nvfp4_4over6_quantizer_roles(): mode="forward", num_quantizers=3, ).make_quantizers() - assert [q.use_4over6 for q in forward_quantizers] == [True, True, True] assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] + assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) + assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) backward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="backward", num_quantizers=2, ).make_quantizers() - assert [q.use_4over6 for q in backward_quantizers] == [True, True] assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] -@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) -def test_nvfp4_grouped_storage_metadata(): - q = NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, - rowwise=True, - columnwise=False, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=False, - stochastic_rounding=False, - row_scaled_nvfp4=True, - use_4over6=True, - ) - - grouped_tensor = GroupedTensorStorage.make_grouped_tensor( - num_tensors=2, - first_dims=None, - last_dims=None, - logical_first_dim=32, - logical_last_dim=64, - quantizer=q, - device=torch.device("cuda"), - dtype=torch.bfloat16, - ) - assert grouped_tensor.row_scaled_nvfp4 - assert grouped_tensor.use_4over6 - - grouped_copy = grouped_tensor.copy() - assert grouped_copy.row_scaled_nvfp4 - assert grouped_copy.use_4over6 - - for tensor in grouped_tensor.split_into_quantized_tensors(): - assert tensor._row_scaled_nvfp4 - assert tensor._use_4over6 - - @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 2b4799d99f..721208a85d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -107,14 +107,8 @@ def nvfp4_4over6(): return nvfp4_recipe -def nvfp4_row_scaled_4over6(): - nvfp4_recipe = recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - row_scaled_activation=True, - enable_4over6=True, - ) +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() @@ -135,7 +129,7 @@ def nvfp4_row_scaled_4over6(): fp8_recipes.append(None) fp8_recipes_with_row_scaled = fp8_recipes.copy() if nvfp4_available: - fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled_4over6()) + fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 9022b2b530..f78959aa2c 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -31,16 +31,6 @@ from transformer_engine.pytorch.module.base import get_dummy_wgrad -_NVFP4_RECIPE_NAMES = ( - "nvfp4", - "nvfp4_4over6", - "nvfp4_row_scaled", - "nvfp4_row_scaled_4over6", -) -_NVFP4_ROW_SCALED_RECIPE_NAMES = ("nvfp4_row_scaled", "nvfp4_row_scaled_4over6") -_NVFP4_4OVER6_RECIPE_NAMES = ("nvfp4_4over6", "nvfp4_row_scaled_4over6") - - def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: """Convert type name to PyTorch dtype""" if isinstance(dtype, torch.dtype): @@ -128,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in _NVFP4_RECIPE_NAMES: + if "nvfp4" in name: return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -155,13 +145,13 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) - if name in _NVFP4_RECIPE_NAMES: + if "nvfp4" in name: kwargs = { "disable_rht": True, "disable_stochastic_rounding": True, "disable_2d_quantization": True, - "row_scaled_activation": name in _NVFP4_ROW_SCALED_RECIPE_NAMES, - "enable_4over6": name in _NVFP4_4OVER6_RECIPE_NAMES, + "row_scaled_activation": "row_scaled" in name, + "enable_4over6": "4over6" in name, } kwargs.update(recipe_kwargs) return transformer_engine.common.recipe.NVFP4BlockScaling(**kwargs) From 69e8f3abf09b365e46b1eba5011543112fa5ca05 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 8 May 2026 23:12:04 -0700 Subject: [PATCH 07/57] Add gemm test Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 21 +++++++++++++++++ .../include/transformer_engine/recipe.h | 3 +++ transformer_engine/common/recipe/nvfp4.cu | 23 +++++++++++++++---- .../custom_recipes/quantization_ref_nvfp4.py | 19 ++++++++++++++- 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index a7ea4f089f..3450c4bd0b 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -28,6 +28,7 @@ def check_nvfp4_gemm_versus_reference( x_columnwise: bool = False, w_columnwise: bool = False, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -59,6 +60,7 @@ def check_nvfp4_gemm_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -68,6 +70,7 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + use_4over6=use_4over6, ) # Quantize x and w @@ -123,6 +126,7 @@ def check_nvfp4_gemm_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -131,6 +135,7 @@ def check_nvfp4_gemm_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + use_4over6=use_4over6, ) # Create reference quantized tensors needed by reference GEMM @@ -232,6 +237,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( *, use_bias: bool, single_output: bool, + use_4over6: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -249,6 +255,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=True, + use_4over6=use_4over6, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -258,6 +265,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + use_4over6=use_4over6, ) x_nvfp4 = [] @@ -321,6 +329,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( M: int, K: int, N: int, + use_4over6: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -339,6 +348,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=True, + use_4over6=use_4over6, ) x_tensorwise_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -348,6 +358,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + use_4over6=use_4over6, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -357,6 +368,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + use_4over6=use_4over6, ) x_row_scaled = x_row_scaled_quantizer.update_quantized( @@ -417,6 +429,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( ids=["rowxrow", "colxrow", "colxcol"], ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -428,6 +441,7 @@ def test_nvfp4_gemm_versus_reference( is_x_columnwise: bool, is_w_columnwise: bool, row_scaled_nvfp4: bool, + use_4over6: bool, ): if row_scaled_nvfp4: if accumulate: @@ -446,6 +460,7 @@ def test_nvfp4_gemm_versus_reference( x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) @@ -471,6 +486,7 @@ def test_nvfp4_gemm_versus_reference( @pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( m_splits: list[int], k: int, @@ -480,6 +496,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( out_dtype: torch.dtype, use_bias: bool, single_output: bool, + use_4over6: bool, ): check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( x_dtype=x_dtype, @@ -490,6 +507,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( n=n, use_bias=use_bias, single_output=single_output, + use_4over6=use_4over6, ) @@ -513,6 +531,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_nvfp4_row_scaled_gemm_matches_emulated( M: int, K: int, @@ -520,6 +539,7 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( x_dtype: torch.dtype, w_dtype: torch.dtype, out_dtype: torch.dtype, + use_4over6: bool, ): check_nvfp4_row_scaled_gemm_matches_emulated( x_dtype=x_dtype, @@ -528,4 +548,5 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( M=M, K=K, N=N, + use_4over6=use_4over6, ) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index cad27a2992..fe729d3b30 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -304,6 +304,9 @@ void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_r * \param[in] alpha_in Input scaling factor. * \param[out] alpha_out Output scaling factor. * \param[in] stream CUDA stream used for the operation. + * + * Uses each NVFP4 tensor's 4over6 metadata to choose the matching FP8 max + * when folding global amax values into the GEMM alpha. */ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, const NVTETensor inpB, const bool use_rowwise_amax_B, diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 1c419d4f8c..9c1397be4f 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -65,15 +65,15 @@ namespace nvfp4_recipe { * --------------------------------------------------------------------------- */ -// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; -constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); constexpr int kTileDim = 16; constexpr int kThreadsPerBlock = 256; // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, - const float *amax_B, float *alpha_out) { - // factor is defined in the enclosing namespace + const float *amax_B, float fp8_max_A, + float fp8_max_B, float *alpha_out) { + constexpr float fp4_max = 6.0f; + const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max_A * fp8_max_B); *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; } @@ -924,6 +924,18 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr; void *alpha_ptr = tOut->data.dptr; + float fp8_max_A; + if (tA->nvfp4_4over6) { + fp8_max_A = 256.0f; + } else { + fp8_max_A = 448.0f; + } + float fp8_max_B; + if (tB->nvfp4_4over6) { + fp8_max_B = 256.0f; + } else { + fp8_max_B = 448.0f; + } // check for not null pointers NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); @@ -932,7 +944,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>( alpha_in, reinterpret_cast(amax_A_ptr), - reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); + reinterpret_cast(amax_B_ptr), fp8_max_A, fp8_max_B, + reinterpret_cast(alpha_ptr)); NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index f0a07b4e4e..3f76037a6a 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -221,6 +221,7 @@ class NVFP4TensorRef(QuantizedTensorStorage): scale_t: Optional[torch.Tensor] = None global_amax_row: Optional[torch.Tensor] = None global_amax_col: Optional[torch.Tensor] = None + use_4over6: bool = False dtype: Optional[torch.dtype] = None device: Optional[torch.device] = None @@ -839,6 +840,7 @@ def quantize( scale_t=sx_t, global_amax_row=global_amax_row, global_amax_col=global_amax_col, + use_4over6=self.use_4over6, dtype=tensor.dtype, device=tensor.device, quant_dtype=self.dtype, @@ -886,6 +888,7 @@ def update_quantized( dst.scale_t = sx_t dst.global_amax_row = global_amax_row dst.global_amax_col = global_amax_col + dst.use_4over6 = self.use_4over6 dst.dtype = src.dtype dst.quant_dtype = self.dtype dst.original_shape = original_shape @@ -991,7 +994,21 @@ def qgemm( sx = sx.to(torch.float32) sw = sw.to(torch.float32) - factor = 6.0 * 6.0 * 448.0 * 448.0 + qresult_x_use_4over6 = getattr( + qresult_x, "use_4over6", getattr(qresult_x, "_use_4over6", self.use_4over6) + ) + qresult_w_use_4over6 = getattr( + qresult_w, "use_4over6", getattr(qresult_w, "_use_4over6", self.use_4over6) + ) + if qresult_x_use_4over6: + fp8_max_x = 256.0 + else: + fp8_max_x = 448.0 + if qresult_w_use_4over6: + fp8_max_w = 256.0 + else: + fp8_max_w = 448.0 + factor = 6.0 * 6.0 * fp8_max_x * fp8_max_w if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col From 009e65183e5602a024b75b697472a14f4912116d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 00:38:16 -0700 Subject: [PATCH 08/57] Add more tests and fix offload Signed-off-by: Ziang Li --- tests/pytorch/test_cpu_offloading.py | 46 ++++++++++++++++++- tests/pytorch/test_sanity.py | 15 ++++-- tests/pytorch/test_torch_compile.py | 30 +++++++++++- tests/pytorch/utils.py | 7 +++ transformer_engine/pytorch/cpu_offload.py | 25 ++++++++-- .../pytorch/csrc/extensions/cast.cpp | 23 ++++++++-- 6 files changed, 129 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 50196782f2..a09c8f6a6f 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -28,6 +28,33 @@ mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available() + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + enable_4over6=True, + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + quantization_recipes: List[Optional[recipe.Recipe]] = [None] if fp8_available: quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) @@ -37,6 +64,8 @@ quantization_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: quantization_recipes.append(recipe.NVFP4BlockScaling()) + quantization_recipes.append(nvfp4_4over6()) + quantization_recipes.append(nvfp4_row_scaled()) model_config = { @@ -176,7 +205,17 @@ def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) return quantizer(tensor) elif recipe.nvfp4(): - quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer() + qparams = recipe.fp4_quant_fwd_inp + quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer( + rowwise=True, + columnwise=not recipe.row_scaled_activation, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + row_scaled_nvfp4=recipe.row_scaled_activation, + use_4over6=recipe.enable_4over6, + ) return quantizer(tensor) @staticmethod @@ -191,7 +230,10 @@ def get_tensor_size_mb(tensor): if tensor is None: return 0 if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage): - return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors()) + tensors = [ + value for value in tensor.get_metadata().values() if isinstance(value, torch.Tensor) + ] + return sum(Utils.get_tensor_size_mb(t) for t in tensors) else: return tensor.numel() * tensor.element_size() / (1024**2) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 721208a85d..b4e0bb3dba 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -108,7 +108,12 @@ def nvfp4_4over6(): def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() @@ -609,12 +614,12 @@ def test_sanity_grouped_linear( single_grouped_bias=single_param, ).cuda() - # Verify grouped linear exposes a single grouped weight parameter(and bias when applicable). + # Verify grouped linear exposes grouped params when the experimental mode is enabled. if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): - if single_param: + if te_grouped_linear.single_grouped_weight: check_grouped_weight(te_grouped_linear, num_gemms, ffn_hidden_size, config.hidden_size) - if use_bias: - check_grouped_bias(te_grouped_linear, num_gemms, ffn_hidden_size) + if use_bias and te_grouped_linear.single_grouped_bias: + check_grouped_bias(te_grouped_linear, num_gemms, ffn_hidden_size) inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 51f72b1e56..c9c163e5de 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -39,6 +39,33 @@ fp8_block_scaling_available = is_fp8_block_scaling_available() nvfp4_available = is_nvfp4_available() + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + enable_4over6=True, + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + _all_recipes: list = [] if fp8_available: _all_recipes.append(recipe.Float8CurrentScaling()) @@ -48,7 +75,8 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) - _all_recipes.append(recipe.NVFP4BlockScaling(row_scaled_activation=True)) + _all_recipes.append(nvfp4_4over6()) + _all_recipes.append(nvfp4_row_scaled()) # --------------------------------------------------------------------------- diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index f78959aa2c..0439112d31 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -184,6 +184,13 @@ def skip_unsupported_backward_override( and backward_override is None ): pytest.skip("Row-scaled NVFP4 does not support default quantized backward.") + if ( + quant_recipe is not None + and quant_recipe.nvfp4() + and getattr(quant_recipe, "enable_4over6", False) + and layer_type == "grouped_linear" + ): + pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index c81c18e64f..371856b15c 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -363,7 +363,9 @@ def start_reload(self): self.bwd_gpu_tensor_group ) - def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]: + def push_tensor( + self, tensor: torch.Tensor, *, ignore_size_threshold: bool = False + ) -> int | torch.Tensor | tuple[list, list]: """ It is called when a tensor is saved for backward pass. @@ -373,16 +375,29 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, """ self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) - if self._check_if_offload(tensor): + if self._check_if_offload(tensor, ignore_size_threshold=ignore_size_threshold): # For QuantizedTensor: decompose into component tensors, push each one recursively if isinstance(tensor, QuantizedTensor): # Make a copy because prepare_for_saving modifies the object (sets fields to None) tensor_copy = tensor.detach() + force_offload_tensor_ids = set() + if getattr(tensor_copy, "_row_scaled_nvfp4", False): + amax_rowwise = getattr(tensor_copy, "_amax_rowwise", None) + if amax_rowwise is not None: + force_offload_tensor_ids.add(id(amax_rowwise)) # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, # so the generic prepare_for_saving would not call tensor.prepare_for_saving() saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() push_results = [ - self.push_tensor(t) if t is not None else None for t in saved_tensors + ( + self.push_tensor( + t, + ignore_size_threshold=id(t) in force_offload_tensor_ids, + ) + if t is not None + else None + ) + for t in saved_tensors ] return (push_results, [tensor_obj]) @@ -451,12 +466,12 @@ def release_all_memory(self): self.bwd_gpu_tensor_group = TensorGroup() self.state = "not_offloaded" - def _check_if_offload(self, t: torch.Tensor) -> bool: + def _check_if_offload(self, t: torch.Tensor, *, ignore_size_threshold: bool = False) -> bool: """ Check if tensor needs to be offloaded. """ # Only offload tensors with at least 256k elements (~1MB for float32) - if t.numel() < 256 * 1024: + if not ignore_size_threshold and t.numel() < 256 * 1024: return False if ( diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index d4a27c61b4..4c781a58ea 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -869,10 +869,25 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back( - NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, - amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], - with_gemm_swizzled_scales, row_scaled_nvfp4, use_4over6)); + py::dict kwargs; + kwargs["rowwise_data"] = rowwise_data; + kwargs["rowwise_scale_inv"] = rowwise_scale; + kwargs["columnwise_data"] = columnwise_data; + kwargs["columnwise_scale_inv"] = columnwise_scale; + kwargs["amax_rowwise"] = amax_rowwise; + kwargs["amax_columnwise"] = amax_columnwise; + kwargs["fp4_dtype"] = py::cast(fp4_dtype); + kwargs["quantizer"] = py::reinterpret_borrow(quantizer_py_list[i]); + kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["use_4over6"] = py::cast(use_4over6); + py::tuple args(0); + PyObject *tensor_py = PyObject_Call(NVFP4TensorClass.ptr(), args.ptr(), kwargs.ptr()); + if (tensor_py == nullptr) { + PyErr_Print(); + } + NVTE_CHECK(tensor_py != nullptr, "Failed to create NVFP4TensorStorage instance"); + tensor_py_list.emplace_back(py::reinterpret_steal(tensor_py)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, From 3153fc3ebcc078b869ebf7185c998d64fba09eb4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 00:55:29 -0700 Subject: [PATCH 09/57] Fix offload Signed-off-by: Ziang Li --- tests/pytorch/test_cpu_offloading.py | 47 +++++++++++++++++++++-- transformer_engine/pytorch/cpu_offload.py | 25 +++--------- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index a09c8f6a6f..d30d1c596d 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -237,6 +237,21 @@ def get_tensor_size_mb(tensor): else: return tensor.numel() * tensor.element_size() / (1024**2) + @staticmethod + def get_saved_tensor_gpu_size_mb(tensor): + if tensor is None or isinstance(tensor, int): + return 0 + if isinstance(tensor, tuple): + push_results, _ = tensor + return Utils.get_saved_tensor_gpu_size_mb(push_results) + if isinstance(tensor, list): + return sum(Utils.get_saved_tensor_gpu_size_mb(t) for t in tensor) + return Utils.get_tensor_size_mb(tensor) + + @staticmethod + def keeps_small_nvfp4_row_amax_on_gpu(recipe: Optional[recipe.Recipe]): + return recipe is not None and recipe.nvfp4() and recipe.row_scaled_activation + @staticmethod def memory_leak_check(): # Should be called before each test. @@ -294,6 +309,25 @@ def test_general(self, random_num_tensors, recipe): offload_layer_state.release_all_memory() torch.cuda.synchronize() + @pytest.mark.skipif(not nvfp4_available, reason="NVFP4 requires Blackwell") + def test_nvfp4_row_scaled_amax_stays_on_gpu(self): + Utils.memory_leak_check() + stream = torch.cuda.Stream() + offload_layer_state = OffloadableLayerState( + offload_stream=stream, + ) + tensor = Utils.create_tensor(nvfp4_row_scaled()) + tensor_id = offload_layer_state.push_tensor(tensor) + assert isinstance(tensor_id, tuple) + push_results, _ = tensor_id + assert isinstance(push_results[0], int) + assert isinstance(push_results[4], torch.Tensor) + assert push_results[4].device.type == "cuda" + assert push_results[4].numel() < 256 * 1024 + del tensor, tensor_id + offload_layer_state.release_all_memory() + torch.cuda.synchronize() + def test_offload_base_tensor(self): Utils.memory_leak_check() stream = torch.cuda.Stream() @@ -405,11 +439,16 @@ def test_memory(self, recipe): del tensor, tensor_id torch.cuda.synchronize() + resident_gpu_size = sum( + Utils.get_saved_tensor_gpu_size_mb(tensor_id) for tensor_id in tensor_ids + ) if recipe is None: assert Utils.get_max_cuda_memory_mb() == pytest.approx( - init_cuda_memory + tensor_size, 0.1 + init_cuda_memory + resident_gpu_size, 0.1 ) - assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1) + assert Utils.get_cuda_memory_mb() == pytest.approx( + init_cuda_memory + resident_gpu_size, 0.1 + ) for i in range(NUM_LAYERS - 1, -1, -1): offload_synchronizer.bwd_step(i) @@ -578,7 +617,9 @@ def test_memory(self, layer_type, recipe, backward_override): out = out + 1 out = sync_function(out) del inp - if backward_override is None: + if Utils.keeps_small_nvfp4_row_amax_on_gpu(recipe): + assert Utils.get_cuda_memory_mb() <= cuda_memory_no_offload + elif backward_override is None: assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) else: assert ( diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 371856b15c..c81c18e64f 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -363,9 +363,7 @@ def start_reload(self): self.bwd_gpu_tensor_group ) - def push_tensor( - self, tensor: torch.Tensor, *, ignore_size_threshold: bool = False - ) -> int | torch.Tensor | tuple[list, list]: + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]: """ It is called when a tensor is saved for backward pass. @@ -375,29 +373,16 @@ def push_tensor( """ self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) - if self._check_if_offload(tensor, ignore_size_threshold=ignore_size_threshold): + if self._check_if_offload(tensor): # For QuantizedTensor: decompose into component tensors, push each one recursively if isinstance(tensor, QuantizedTensor): # Make a copy because prepare_for_saving modifies the object (sets fields to None) tensor_copy = tensor.detach() - force_offload_tensor_ids = set() - if getattr(tensor_copy, "_row_scaled_nvfp4", False): - amax_rowwise = getattr(tensor_copy, "_amax_rowwise", None) - if amax_rowwise is not None: - force_offload_tensor_ids.add(id(amax_rowwise)) # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, # so the generic prepare_for_saving would not call tensor.prepare_for_saving() saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() push_results = [ - ( - self.push_tensor( - t, - ignore_size_threshold=id(t) in force_offload_tensor_ids, - ) - if t is not None - else None - ) - for t in saved_tensors + self.push_tensor(t) if t is not None else None for t in saved_tensors ] return (push_results, [tensor_obj]) @@ -466,12 +451,12 @@ def release_all_memory(self): self.bwd_gpu_tensor_group = TensorGroup() self.state = "not_offloaded" - def _check_if_offload(self, t: torch.Tensor, *, ignore_size_threshold: bool = False) -> bool: + def _check_if_offload(self, t: torch.Tensor) -> bool: """ Check if tensor needs to be offloaded. """ # Only offload tensors with at least 256k elements (~1MB for float32) - if not ignore_size_threshold and t.numel() < 256 * 1024: + if t.numel() < 256 * 1024: return False if ( From e31b758b389a0a66a5863bf2e1e21d063df9c158 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 01:03:55 -0700 Subject: [PATCH 10/57] Clean up arg Signed-off-by: Ziang Li --- .../pytorch/csrc/extensions/cast.cpp | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 4c781a58ea..ceb654076d 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -869,25 +869,10 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - py::dict kwargs; - kwargs["rowwise_data"] = rowwise_data; - kwargs["rowwise_scale_inv"] = rowwise_scale; - kwargs["columnwise_data"] = columnwise_data; - kwargs["columnwise_scale_inv"] = columnwise_scale; - kwargs["amax_rowwise"] = amax_rowwise; - kwargs["amax_columnwise"] = amax_columnwise; - kwargs["fp4_dtype"] = py::cast(fp4_dtype); - kwargs["quantizer"] = py::reinterpret_borrow(quantizer_py_list[i]); - kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); - kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); - kwargs["use_4over6"] = py::cast(use_4over6); - py::tuple args(0); - PyObject *tensor_py = PyObject_Call(NVFP4TensorClass.ptr(), args.ptr(), kwargs.ptr()); - if (tensor_py == nullptr) { - PyErr_Print(); - } - NVTE_CHECK(tensor_py != nullptr, "Failed to create NVFP4TensorStorage instance"); - tensor_py_list.emplace_back(py::reinterpret_steal(tensor_py)); + tensor_py_list.emplace_back(NVFP4TensorClass( + rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, + amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales, + py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, py::arg("use_4over6") = use_4over6)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, From fcd526ca9a24112f3bacea78e5872c3cc7e3a7ab Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 01:44:51 -0700 Subject: [PATCH 11/57] Add more test Signed-off-by: Ziang Li --- tests/pytorch/test_cuda_graphs.py | 14 ++++++-- tests/pytorch/test_fusible_ops.py | 32 ++++++++++++----- tests/pytorch/test_numerics.py | 14 ++++++++ tests/pytorch/test_quantized_tensor.py | 21 ++++++++---- tests/pytorch/test_sanity.py | 27 +++++++++++---- .../pytorch/tensor/nvfp4_tensor.py | 34 ++++++++++++++++++- 6 files changed, 117 insertions(+), 25 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 7ed8ebdb22..7df90a6b9a 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -64,12 +64,19 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe -def nvfp4_row_scaled_4over6(): +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - row_scaled_activation=True, enable_4over6=True, ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() @@ -106,7 +113,8 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) - fp8_recipes.append(nvfp4_row_scaled_4over6()) + fp8_recipes.append(nvfp4_row_scaled()) + fp8_recipes.append(nvfp4_4over6()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7691582f97..3c99b65360 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -77,6 +77,17 @@ _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") + _quantization_list.append("nvfp4_4over6") + + +def is_nvfp4_quantization(quantization: Optional[str]) -> bool: + """Whether a quantization recipe uses NVFP4.""" + return quantization is not None and "nvfp4" in quantization + + +def is_nvfp4_4over6_quantization(quantization: Optional[str]) -> bool: + """Whether a quantization recipe uses NVFP4 4over6.""" + return quantization is not None and "4over6" in quantization @pytest.fixture(autouse=True, scope="class") @@ -106,7 +117,7 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if quantization == "nvfp4" and not nvfp4_available: + if is_nvfp4_quantization(quantization) and not nvfp4_available: pytest.skip(reason_for_no_nvfp4) # Check dims @@ -119,13 +130,13 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization == "nvfp4": + elif is_nvfp4_quantization(quantization): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: - if quantization == "nvfp4" and dtype != torch.bfloat16: + if is_nvfp4_quantization(quantization) and dtype != torch.bfloat16: pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -180,13 +191,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization == "nvfp4": + elif is_nvfp4_quantization(quantization): test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, with_2d_quantization=False, stochastic_rounding=False, with_random_sign_mask=False, + use_4over6=is_nvfp4_4over6_quantization(quantization), )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -1512,7 +1524,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif quantization == "nvfp4": + elif is_nvfp4_quantization(quantization): tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1883,7 +1895,7 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute and quantization == "nvfp4": + if quantized_compute and is_nvfp4_quantization(quantization): tols = dtype_tols(tex.DType.kFloat4E2M1) elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) @@ -2076,6 +2088,8 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if is_nvfp4_4over6_quantization(quantization): + pytest.skip("NVFP4 4over6 grouped quantization is not supported") if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") @@ -3608,7 +3622,9 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: + if is_nvfp4_4over6_quantization(quantization): + pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if is_nvfp4_quantization(quantization) and activation == "scaled_clamped_qgeglu" and bias: # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") @@ -3830,7 +3846,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization == "nvfp4": + if is_nvfp4_quantization(quantization): tols = {"rtol": 0.25, "atol": 0.5} # Check values diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a718ea2a8a..357f7b294d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -138,6 +138,19 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + enable_4over6=True, + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -171,6 +184,7 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes.append(recipe.DelayedScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + fp8_recipes.append(nvfp4_4over6()) use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 526045e43e..94f5dd040c 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -28,7 +28,7 @@ import transformer_engine_torch as tex from references.ref_per_tensor_cs import ref_per_tensor_cs_cast -from utils import assert_close, quantization_tols +from utils import assert_close # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] @@ -69,6 +69,8 @@ def _to_list(x: Union[Iterable, Any]) -> List: _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") + _quantization_list.append("nvfp4_row_scaled") + _quantization_list.append("nvfp4_4over6") # delayed scaling @@ -163,13 +165,17 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + row_scaled_nvfp4 = quantization == "nvfp4_row_scaled" test = NVFP4Quantizer( + columnwise=not row_scaled_nvfp4, with_rht=False, with_post_rht_amax=False, with_2d_quantization=False, stochastic_rounding=False, + row_scaled_nvfp4=row_scaled_nvfp4, with_random_sign_mask=False, + use_4over6=(quantization == "nvfp4_4over6"), )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -735,13 +741,16 @@ def test_update_nd_tensor( ) elif quantization == "mxfp8": quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) - elif quantization in ("nvfp4", "nvfp4_2d"): + elif quantization in ("nvfp4", "nvfp4_2d", "nvfp4_row_scaled", "nvfp4_4over6"): + row_scaled_nvfp4 = quantization == "nvfp4_row_scaled" quantizer = NVFP4Quantizer( rowwise=True, - columnwise=True, + columnwise=not row_scaled_nvfp4, with_rht=False, with_post_rht_amax=False, with_2d_quantization=(quantization == "nvfp4_2d"), + row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=(quantization == "nvfp4_4over6"), ) quantization = "nvfp4" else: @@ -756,9 +765,9 @@ def test_update_nd_tensor( q_x.copy_(x_new) # Check results + q_ref = quantizer(x_new) assert q_x.shape == torch.Size(shape) - tols = quantization_tols(quantization) - assert_close(q_x, x_new, **tols) + assert_close(q_x, q_ref, rtol=0, atol=0) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index b4e0bb3dba..1beac7e829 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -132,9 +132,6 @@ def nvfp4_row_scaled(): fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) -fp8_recipes_with_row_scaled = fp8_recipes.copy() -if nvfp4_available: - fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -434,7 +431,11 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -482,7 +483,11 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -520,7 +525,11 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -561,7 +570,11 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 1cb906d61f..2d0ed5ba89 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -535,6 +535,8 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m columnwise_usage, self._amax_rowwise, self._amax_columnwise, + self._row_scaled_nvfp4, + self._use_4over6, self.shape[-1], ) return sharded_tensors, metadata @@ -553,7 +555,15 @@ def fsdp_post_all_gather( all-gathered rowwise data. Columnwise data is derived locally via _create_columnwise() instead of being all-gathered. """ - fp4_dtype, columnwise_usage, amax_rowwise, amax_columnwise, K = metadata + ( + fp4_dtype, + columnwise_usage, + amax_rowwise, + amax_columnwise, + row_scaled_nvfp4, + use_4over6, + K, + ) = metadata # Only rowwise data+scales were all-gathered rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] @@ -576,6 +586,8 @@ def fsdp_post_all_gather( out._rowwise_scale_inv = rowwise_scale_inv out._amax_rowwise = amax_rowwise out._amax_columnwise = amax_columnwise + out._row_scaled_nvfp4 = row_scaled_nvfp4 + out._use_4over6 = use_4over6 else: # Construct new tensor (first iteration) out = NVFP4Tensor( @@ -591,6 +603,8 @@ def fsdp_post_all_gather( quantizer=self._quantizer, requires_grad=False, with_gemm_swizzled_scales=False, + row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) # Derive columnwise data locally via transpose instead of all-gathering it @@ -729,6 +743,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): quantizer=tensor._quantizer, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + row_scaled_nvfp4=tensor._row_scaled_nvfp4, + use_4over6=tensor._use_4over6, ) # Default case @@ -748,6 +764,8 @@ def _make_in_reduce_ex( dtype: torch.dtype, quantizer: Quantizer, with_gemm_swizzled_scales: bool = False, + row_scaled_nvfp4: bool = False, + use_4over6: bool = False, ) -> NVFP4Tensor: """Build NVFP4Tensor, for use in __reduce__ @@ -768,6 +786,8 @@ def _make_in_reduce_ex( quantizer=quantizer, requires_grad=False, with_gemm_swizzled_scales=with_gemm_swizzled_scales, + row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -786,6 +806,8 @@ def __reduce_ex__(self, protocol: int) -> tuple: self.dtype, self._quantizer, self._with_gemm_swizzled_scales, + self._row_scaled_nvfp4, + self._use_4over6, ), ) @@ -838,6 +860,8 @@ def _set_data(self, tensor: torch.Tensor) -> None: self._amax_rowwise = tensor._amax_rowwise self._amax_columnwise = tensor._amax_columnwise self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales + self._row_scaled_nvfp4 = tensor._row_scaled_nvfp4 + self._use_4over6 = tensor._use_4over6 return # Quantize to FP8 @@ -958,6 +982,8 @@ def forward( fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + row_scaled_nvfp4=tensor._row_scaled_nvfp4, + use_4over6=tensor._use_4over6, ) @staticmethod @@ -1000,6 +1026,8 @@ def backward( fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + row_scaled_nvfp4=grad._row_scaled_nvfp4, + use_4over6=grad._use_4over6, ) return dgrad, None return grad.view(ctx.shape), None @@ -1084,6 +1112,8 @@ def forward( fp4_dtype=tensor._fp4_dtype, requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, + row_scaled_nvfp4=tensor._row_scaled_nvfp4, + use_4over6=tensor._use_4over6, ) @staticmethod @@ -1126,6 +1156,8 @@ def backward( fp4_dtype=grad._fp4_dtype, requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, + row_scaled_nvfp4=grad._row_scaled_nvfp4, + use_4over6=grad._use_4over6, ) return dgrad, None return grad.view(ctx.shape), None From 100c378ebd1ad49c4afdc2f21e29317b0b18f7b7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 18:40:59 -0700 Subject: [PATCH 12/57] Add more tests Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 5 ---- tests/pytorch/test_cpu_offloading.py | 9 +++--- tests/pytorch/test_cuda_graphs.py | 8 ++++- tests/pytorch/test_numerics.py | 39 ++++++++++++++++++++++--- tests/pytorch/test_sanity.py | 9 +++--- tests/pytorch/test_torch_compile.py | 9 +++--- 6 files changed, 57 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 5a4327b18d..5f54267a8f 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -88,11 +88,6 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP44Over6BlockScaling", ), - pytest.param( - "nvfp4_row_scaled_4over6", - marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), - id="NVFP4RowScaled4Over6BlockScaling", - ), ] diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index d30d1c596d..a897e492c5 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -29,12 +29,13 @@ nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available() -def nvfp4_4over6(): +def nvfp4_row_scaled(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + row_scaled_activation=True, + backward_override="dequantized", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() @@ -42,12 +43,12 @@ def nvfp4_4over6(): return nvfp4_recipe -def nvfp4_row_scaled(): +def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - row_scaled_activation=True, + enable_4over6=True, ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 7df90a6b9a..0526b7b99a 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -65,7 +65,13 @@ def nvfp4_rht_and_2d_quantization(): def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 357f7b294d..b19a27fd0d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -54,7 +54,7 @@ from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex -from utils import ModelConfig, reset_rng_states +from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override # Only run FP8 tests on supported devices. @@ -138,6 +138,20 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, @@ -1908,7 +1922,10 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize( + "recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []) + [None], +) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -1931,6 +1948,9 @@ def test_grouped_linear_accuracy( pytest.skip("FP8 parameters are not supported in debug mode.") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2075,6 +2095,7 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override("grouped_linear", recipe, None) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2153,7 +2174,10 @@ def test_grouped_linear_accuracy_save_original_input( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize( + "recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []) + [None], +) def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( @@ -2267,7 +2291,10 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize( + "recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), +) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( dtype, @@ -2281,6 +2308,9 @@ def test_padding_grouped_linear_accuracy( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2358,6 +2388,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") + skip_unsupported_backward_override("grouped_linear", recipe, None) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 1beac7e829..80d057c6d5 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -94,12 +94,13 @@ def nvfp4_vanilla(): return nvfp4_recipe -def nvfp4_4over6(): +def nvfp4_row_scaled(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + row_scaled_activation=True, + backward_override="dequantized", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() @@ -107,12 +108,12 @@ def nvfp4_4over6(): return nvfp4_recipe -def nvfp4_row_scaled(): +def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - row_scaled_activation=True, + enable_4over6=True, ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index c9c163e5de..74be8cb358 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -40,12 +40,13 @@ nvfp4_available = is_nvfp4_available() -def nvfp4_4over6(): +def nvfp4_row_scaled(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + row_scaled_activation=True, + backward_override="dequantized", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() @@ -53,12 +54,12 @@ def nvfp4_4over6(): return nvfp4_recipe -def nvfp4_row_scaled(): +def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - row_scaled_activation=True, + enable_4over6=True, ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() From 1c9f26b3f08be5394bae6a5b36d6e64b0115dd5d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 21:02:33 -0700 Subject: [PATCH 13/57] Clean up test Signed-off-by: Ziang Li --- tests/pytorch/test_fusible_ops.py | 32 +++++++++----------------- tests/pytorch/test_numerics.py | 38 +++++++++++++++++++------------ 2 files changed, 34 insertions(+), 36 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3c99b65360..d612fea552 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -80,16 +80,6 @@ _quantization_list.append("nvfp4_4over6") -def is_nvfp4_quantization(quantization: Optional[str]) -> bool: - """Whether a quantization recipe uses NVFP4.""" - return quantization is not None and "nvfp4" in quantization - - -def is_nvfp4_4over6_quantization(quantization: Optional[str]) -> bool: - """Whether a quantization recipe uses NVFP4 4over6.""" - return quantization is not None and "4over6" in quantization - - @pytest.fixture(autouse=True, scope="class") def _reset_rng_states_per_test(): """Restore torch, CUDA, and Python ``random`` before each test in this module.""" @@ -117,7 +107,7 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if is_nvfp4_quantization(quantization) and not nvfp4_available: + if "nvfp4" in quantization and not nvfp4_available: pytest.skip(reason_for_no_nvfp4) # Check dims @@ -130,13 +120,13 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif is_nvfp4_quantization(quantization): + elif "nvfp4" in quantization: if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: - if is_nvfp4_quantization(quantization) and dtype != torch.bfloat16: + if "nvfp4" in quantization and dtype != torch.bfloat16: pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -191,14 +181,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif is_nvfp4_quantization(quantization): + elif "nvfp4" in quantization: test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, with_2d_quantization=False, stochastic_rounding=False, with_random_sign_mask=False, - use_4over6=is_nvfp4_4over6_quantization(quantization), + use_4over6="4over6" in quantization, )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -1524,7 +1514,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif is_nvfp4_quantization(quantization): + elif "nvfp4" in quantization: tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1895,7 +1885,7 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute and is_nvfp4_quantization(quantization): + if quantized_compute and "nvfp4" in quantization: tols = dtype_tols(tex.DType.kFloat4E2M1) elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) @@ -2088,7 +2078,7 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if is_nvfp4_4over6_quantization(quantization): + if quantization and "4over6" in quantization: pytest.skip("NVFP4 4over6 grouped quantization is not supported") if single_grouped_bias and not bias: @@ -3622,9 +3612,9 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if is_nvfp4_4over6_quantization(quantization): + if with_quantization and "4over6" in quantization: pytest.skip("NVFP4 4over6 grouped quantization is not supported") - if is_nvfp4_quantization(quantization) and activation == "scaled_clamped_qgeglu" and bias: + if "nvfp4" in quantization and activation == "scaled_clamped_qgeglu" and bias: # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") @@ -3846,7 +3836,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if is_nvfp4_quantization(quantization): + if "nvfp4" in quantization: tols = {"rtol": 0.25, "atol": 0.5} # Check values diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b19a27fd0d..430f5c12f3 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -144,7 +144,7 @@ def nvfp4_row_scaled(): disable_stochastic_rounding=True, disable_2d_quantization=True, row_scaled_activation=True, - backward_override="dequantized", + backward_override="high_precision", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() @@ -199,6 +199,7 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) fp8_recipes.append(nvfp4_4over6()) + fp8_recipes.append(nvfp4_row_scaled()) use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper @@ -660,6 +661,10 @@ def _test_e2e_selective_recompute( def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 or fp8_model_params: + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if fp8 and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( @@ -775,6 +780,10 @@ def test_gpt_full_activation_recompute( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 or fp8_model_params: + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if fp8 and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( @@ -1361,6 +1370,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") + skip_unsupported_backward_override("linear", recipe, getattr(recipe, "backward_override", None)) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -1922,10 +1932,7 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize( - "recipe", - fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []) + [None], -) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -2095,7 +2102,9 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") - skip_unsupported_backward_override("grouped_linear", recipe, None) + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2174,10 +2183,7 @@ def test_grouped_linear_accuracy_save_original_input( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize( - "recipe", - fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []) + [None], -) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( @@ -2291,10 +2297,7 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize( - "recipe", - fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), -) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( dtype, @@ -2388,7 +2391,9 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") - skip_unsupported_backward_override("grouped_linear", recipe, None) + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2608,6 +2613,9 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): From 93fe922c1fc4e1e45a779a2e1444a48ff07ac577 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 22:15:05 -0700 Subject: [PATCH 14/57] Refactor cuh kernel impl Signed-off-by: Ziang Li --- tests/pytorch/test_recipe.py | 2 + .../common/cast/nvfp4/core_nvfp4.cuh | 142 -------- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 316 ++++++++++++++++++ .../quantize_transpose_nvfp4_tuned_1D.cuh | 243 +++++--------- ...quantize_transpose_vector_blockwise_fp4.cu | 83 +---- 5 files changed, 425 insertions(+), 361 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 57305437e7..095fd60b8c 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -573,6 +573,8 @@ def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N): assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 assert new_tensor._use_4over6 == use_4over6 assert new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) + # 4over6 can re-encode a dequantized block with the alternate 4/6 scale + # choice while preserving the dequantized values. if not use_4over6: torch.testing.assert_close( new_tensor._rowwise_data, diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index fcd88d9585..92a6b7963e 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -13,7 +13,6 @@ #include #include -#include #include #include @@ -92,147 +91,6 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const return global_encode_scale; } -__device__ __forceinline__ void compute_4over6_decoding_scaling_factors( - const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, - nvfp4_scale_t &S_dec_b_fp8_map6) { - constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f - const float sf_high_precision = block_amax / fp4_max * S_enc; - S_dec_b_fp8_map4 = static_cast(sf_high_precision * 1.5f); - S_dec_b_fp8_map6 = static_cast(sf_high_precision); -} - -template -__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float (&x)[8], - const float block_scale_inverse, - const nvfp4_scale_t S_dec_b_fp8, - const float global_amax, - float *err) { - uint32_t out = 0; - uint32_t out_dequant_1 = 0; - uint32_t out_dequant_2 = 0; - uint32_t out_dequant_3 = 0; - uint32_t out_dequant_4 = 0; - - constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; - if constexpr (is_blackwell) { - float x_scaled[8]; - if constexpr (USE_FAST_MATH) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - x_scaled[i] = x[i] * block_scale_inverse; - } - } else { - x_scaled[0] = __fmul_rn(x[0], block_scale_inverse); - x_scaled[1] = __fmul_rn(x[1], block_scale_inverse); - x_scaled[2] = __fmul_rn(x[2], block_scale_inverse); - x_scaled[3] = __fmul_rn(x[3], block_scale_inverse); - x_scaled[4] = __fmul_rn(x[4], block_scale_inverse); - x_scaled[5] = __fmul_rn(x[5], block_scale_inverse); - x_scaled[6] = __fmul_rn(x[6], block_scale_inverse); - x_scaled[7] = __fmul_rn(x[7], block_scale_inverse); - } - - asm volatile( - "{\n" - ".reg .b8 byte0, byte1, byte2, byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %8, %7;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %10, %9;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %12, %11;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "cvt.rn.f16x2.e2m1x2 %1, byte0;\n" - "cvt.rn.f16x2.e2m1x2 %2, byte1;\n" - "cvt.rn.f16x2.e2m1x2 %3, byte2;\n" - "cvt.rn.f16x2.e2m1x2 %4, byte3;\n" - "}" - : "=r"(out), "=r"(out_dequant_1), "=r"(out_dequant_2), "=r"(out_dequant_3), - "=r"(out_dequant_4) - : "f"(x_scaled[0]), "f"(x_scaled[1]), "f"(x_scaled[2]), "f"(x_scaled[3]), "f"(x_scaled[4]), - "f"(x_scaled[5]), "f"(x_scaled[6]), "f"(x_scaled[7])); - - const uint16_t out_dequant_1_hi = (out_dequant_1 >> 16) & 0xFFFF; - const uint16_t out_dequant_1_lo = out_dequant_1 & 0xFFFF; - const uint16_t out_dequant_2_hi = (out_dequant_2 >> 16) & 0xFFFF; - const uint16_t out_dequant_2_lo = out_dequant_2 & 0xFFFF; - const uint16_t out_dequant_3_hi = (out_dequant_3 >> 16) & 0xFFFF; - const uint16_t out_dequant_3_lo = out_dequant_3 & 0xFFFF; - const uint16_t out_dequant_4_hi = (out_dequant_4 >> 16) & 0xFFFF; - const uint16_t out_dequant_4_lo = out_dequant_4 & 0xFFFF; - - constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f - constexpr float fp8_4over6_max = 256.0f; - constexpr float mse_denom = fp4_max * fp8_4over6_max; - const float sf = static_cast(S_dec_b_fp8); - if constexpr (USE_FAST_MATH) { - const float dequant[8] = { - __half2float(__ushort_as_half(out_dequant_1_lo)), - __half2float(__ushort_as_half(out_dequant_1_hi)), - __half2float(__ushort_as_half(out_dequant_2_lo)), - __half2float(__ushort_as_half(out_dequant_2_hi)), - __half2float(__ushort_as_half(out_dequant_3_lo)), - __half2float(__ushort_as_half(out_dequant_3_hi)), - __half2float(__ushort_as_half(out_dequant_4_lo)), - __half2float(__ushort_as_half(out_dequant_4_hi)), - }; -#pragma unroll - for (int i = 0; i < 8; ++i) { - const float val = dequant[i] * sf * global_amax / mse_denom; - const float diff = val - x[i]; - *err += diff * diff; - } - } else { - const float val0 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_lo)), sf), global_amax), - mse_denom); - const float val1 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_hi)), sf), global_amax), - mse_denom); - const float val2 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_lo)), sf), global_amax), - mse_denom); - const float val3 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_hi)), sf), global_amax), - mse_denom); - const float val4 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_lo)), sf), global_amax), - mse_denom); - const float val5 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_hi)), sf), global_amax), - mse_denom); - const float val6 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_lo)), sf), global_amax), - mse_denom); - const float val7 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_hi)), sf), global_amax), - mse_denom); - - const float diff0 = __fsub_rn(val0, x[0]); - const float diff1 = __fsub_rn(val1, x[1]); - const float diff2 = __fsub_rn(val2, x[2]); - const float diff3 = __fsub_rn(val3, x[3]); - const float diff4 = __fsub_rn(val4, x[4]); - const float diff5 = __fsub_rn(val5, x[5]); - const float diff6 = __fsub_rn(val6, x[6]); - const float diff7 = __fsub_rn(val7, x[7]); - - *err = __fadd_rn(*err, __fmul_rn(diff0, diff0)); - *err = __fadd_rn(*err, __fmul_rn(diff1, diff1)); - *err = __fadd_rn(*err, __fmul_rn(diff2, diff2)); - *err = __fadd_rn(*err, __fmul_rn(diff3, diff3)); - *err = __fadd_rn(*err, __fmul_rn(diff4, diff4)); - *err = __fadd_rn(*err, __fmul_rn(diff5, diff5)); - *err = __fadd_rn(*err, __fmul_rn(diff6, diff6)); - *err = __fadd_rn(*err, __fmul_rn(diff7, diff7)); - } - } else { - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); - } - - return out; -} - __device__ __forceinline__ uint32_t get_rbits( transformer_engine::curanddx::detail::philox4x32_native_state &rng, diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh new file mode 100644 index 0000000000..eada49ed45 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -0,0 +1,316 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_4over6_nvfp4.cuh + * \brief Helpers used by NVFP4 4over6 quantization. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace core { + +#if FP4_TYPE_SUPPORTED + +__device__ __forceinline__ void compute_4over6_decoding_scaling_factors( + const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, + nvfp4_scale_t &S_dec_b_fp8_map6) { + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + const float sf_high_precision = block_amax / fp4_max * S_enc; + S_dec_b_fp8_map4 = static_cast(sf_high_precision * 1.5f); + S_dec_b_fp8_map6 = static_cast(sf_high_precision); +} + +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float (&x)[8], + const float block_scale_inverse, + const nvfp4_scale_t S_dec_b_fp8, + const float global_amax, + float *err) { + uint32_t out = 0; + uint32_t out_dequant_1 = 0; + uint32_t out_dequant_2 = 0; + uint32_t out_dequant_3 = 0; + uint32_t out_dequant_4 = 0; + + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + float x_scaled[8]; + if constexpr (USE_FAST_MATH) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + x_scaled[i] = x[i] * block_scale_inverse; + } + } else { + x_scaled[0] = __fmul_rn(x[0], block_scale_inverse); + x_scaled[1] = __fmul_rn(x[1], block_scale_inverse); + x_scaled[2] = __fmul_rn(x[2], block_scale_inverse); + x_scaled[3] = __fmul_rn(x[3], block_scale_inverse); + x_scaled[4] = __fmul_rn(x[4], block_scale_inverse); + x_scaled[5] = __fmul_rn(x[5], block_scale_inverse); + x_scaled[6] = __fmul_rn(x[6], block_scale_inverse); + x_scaled[7] = __fmul_rn(x[7], block_scale_inverse); + } + + asm volatile( + "{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %8, %7;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %10, %9;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %12, %11;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "cvt.rn.f16x2.e2m1x2 %1, byte0;\n" + "cvt.rn.f16x2.e2m1x2 %2, byte1;\n" + "cvt.rn.f16x2.e2m1x2 %3, byte2;\n" + "cvt.rn.f16x2.e2m1x2 %4, byte3;\n" + "}" + : "=r"(out), "=r"(out_dequant_1), "=r"(out_dequant_2), "=r"(out_dequant_3), + "=r"(out_dequant_4) + : "f"(x_scaled[0]), "f"(x_scaled[1]), "f"(x_scaled[2]), "f"(x_scaled[3]), "f"(x_scaled[4]), + "f"(x_scaled[5]), "f"(x_scaled[6]), "f"(x_scaled[7])); + + const uint16_t out_dequant_1_hi = (out_dequant_1 >> 16) & 0xFFFF; + const uint16_t out_dequant_1_lo = out_dequant_1 & 0xFFFF; + const uint16_t out_dequant_2_hi = (out_dequant_2 >> 16) & 0xFFFF; + const uint16_t out_dequant_2_lo = out_dequant_2 & 0xFFFF; + const uint16_t out_dequant_3_hi = (out_dequant_3 >> 16) & 0xFFFF; + const uint16_t out_dequant_3_lo = out_dequant_3 & 0xFFFF; + const uint16_t out_dequant_4_hi = (out_dequant_4 >> 16) & 0xFFFF; + const uint16_t out_dequant_4_lo = out_dequant_4 & 0xFFFF; + + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_4over6_max = 256.0f; + constexpr float mse_denom = fp4_max * fp8_4over6_max; + const float sf = static_cast(S_dec_b_fp8); + if constexpr (USE_FAST_MATH) { + const float dequant[8] = { + __half2float(__ushort_as_half(out_dequant_1_lo)), + __half2float(__ushort_as_half(out_dequant_1_hi)), + __half2float(__ushort_as_half(out_dequant_2_lo)), + __half2float(__ushort_as_half(out_dequant_2_hi)), + __half2float(__ushort_as_half(out_dequant_3_lo)), + __half2float(__ushort_as_half(out_dequant_3_hi)), + __half2float(__ushort_as_half(out_dequant_4_lo)), + __half2float(__ushort_as_half(out_dequant_4_hi)), + }; +#pragma unroll + for (int i = 0; i < 8; ++i) { + const float val = dequant[i] * sf * global_amax / mse_denom; + const float diff = val - x[i]; + *err += diff * diff; + } + } else { + const float val0 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_lo)), sf), global_amax), + mse_denom); + const float val1 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_hi)), sf), global_amax), + mse_denom); + const float val2 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_lo)), sf), global_amax), + mse_denom); + const float val3 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_hi)), sf), global_amax), + mse_denom); + const float val4 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_lo)), sf), global_amax), + mse_denom); + const float val5 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_hi)), sf), global_amax), + mse_denom); + const float val6 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_lo)), sf), global_amax), + mse_denom); + const float val7 = __fdiv_rn( + __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_hi)), sf), global_amax), + mse_denom); + + const float diff0 = __fsub_rn(val0, x[0]); + const float diff1 = __fsub_rn(val1, x[1]); + const float diff2 = __fsub_rn(val2, x[2]); + const float diff3 = __fsub_rn(val3, x[3]); + const float diff4 = __fsub_rn(val4, x[4]); + const float diff5 = __fsub_rn(val5, x[5]); + const float diff6 = __fsub_rn(val6, x[6]); + const float diff7 = __fsub_rn(val7, x[7]); + + *err = __fadd_rn(*err, __fmul_rn(diff0, diff0)); + *err = __fadd_rn(*err, __fmul_rn(diff1, diff1)); + *err = __fadd_rn(*err, __fmul_rn(diff2, diff2)); + *err = __fadd_rn(*err, __fmul_rn(diff3, diff3)); + *err = __fadd_rn(*err, __fmul_rn(diff4, diff4)); + *err = __fadd_rn(*err, __fmul_rn(diff5, diff5)); + *err = __fadd_rn(*err, __fmul_rn(diff6, diff6)); + *err = __fadd_rn(*err, __fmul_rn(diff7, diff7)); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + + return out; +} + +template +__device__ __forceinline__ void quantize_4over6_16x( + const float (&first_half)[8], const float (&second_half)[8], + const nvfp4_scale_t S_dec_b_fp8_map4, const nvfp4_scale_t S_dec_b_fp8_map6, + const scaling_coeff_type SFcoefficient_map4, const scaling_coeff_type SFcoefficient_map6, + const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + float err_map4 = 0.0f; + float err_map6 = 0.0f; + __align__(8) uint32_t rOut_map4[2]; + __align__(8) uint32_t rOut_map6[2]; + + if constexpr (REVERSE_PACK_ORDER) { + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + second_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, + &err_map4); + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + second_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, + &err_map6); + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + first_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, + &err_map4); + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + first_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, + &err_map6); + } else { + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + first_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, + &err_map4); + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + first_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, + &err_map6); + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + second_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, + &err_map4); + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + second_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, + &err_map6); + } + + if (err_map4 < err_map6) { + S_dec_b_fp8 = S_dec_b_fp8_map4; + rOut[0] = rOut_map4[0]; + rOut[1] = rOut_map4[1]; + } else { + S_dec_b_fp8 = S_dec_b_fp8_map6; + rOut[0] = rOut_map6[0]; + rOut[1] = rOut_map6[1]; + } +} + +template +__device__ __forceinline__ void store_4over6_packed_16x(const uint32_t (&packed)[2], + output_vec_type &output_vec) { + *reinterpret_cast(&output_vec.data.elt[0]) = packed[0]; + *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; +} + +template +__device__ __forceinline__ void quantize_4over6_contiguous_16x( + const input_type *x, const nvfp4_scale_t S_dec_b_fp8_map4, const nvfp4_scale_t S_dec_b_fp8_map6, + const scaling_coeff_type SFcoefficient_map4, const scaling_coeff_type SFcoefficient_map6, + const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + float first_half[8]; + float second_half[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + first_half[i] = static_cast(x[i]); + second_half[i] = static_cast(x[i + 8]); + } + + quantize_4over6_16x( + first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, + SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); +} + +template +__device__ __forceinline__ void quantize_4over6_pair_array_16x( + const pair_type (&x)[2][4], const nvfp4_scale_t S_dec_b_fp8_map4, + const nvfp4_scale_t S_dec_b_fp8_map6, const scaling_coeff_type SFcoefficient_map4, + const scaling_coeff_type SFcoefficient_map6, const float global_amax, + nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + float first_half[8]; + float second_half[8]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + first_half[2 * i] = static_cast(x[0][i].x); + first_half[2 * i + 1] = static_cast(x[0][i].y); + second_half[2 * i] = static_cast(x[1][i].x); + second_half[2 * i + 1] = static_cast(x[1][i].y); + } + + quantize_4over6_16x( + first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, + SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); +} + +template +__device__ __forceinline__ void quantize_4over6_vec2_array_16x( + const vec_type (&x)[8], const nvfp4_scale_t S_dec_b_fp8_map4, + const nvfp4_scale_t S_dec_b_fp8_map6, const scaling_coeff_type SFcoefficient_map4, + const scaling_coeff_type SFcoefficient_map6, const float global_amax, + nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + float first_half[8]; + float second_half[8]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + first_half[2 * i] = static_cast(x[i].data.elt[0]); + first_half[2 * i + 1] = static_cast(x[i].data.elt[1]); + second_half[2 * i] = static_cast(x[i + 4].data.elt[0]); + second_half[2 * i + 1] = static_cast(x[i + 4].data.elt[1]); + } + + quantize_4over6_16x( + first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, + SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); +} + +template +__device__ __forceinline__ void quantize_4over6_vec_index_16x( + const vec_type (&x)[16], const int idx, const nvfp4_scale_t S_dec_b_fp8_map4, + const nvfp4_scale_t S_dec_b_fp8_map6, const scaling_coeff_type SFcoefficient_map4, + const scaling_coeff_type SFcoefficient_map6, const float global_amax, + nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + float first_half[8]; + float second_half[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + first_half[i] = static_cast(x[i].data.elt[idx]); + second_half[i] = static_cast(x[i + 8].data.elt[idx]); + } + + quantize_4over6_16x( + first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, + SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace core +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 998539836f..6bfda2f24e 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -21,6 +21,7 @@ #include "../../../util/ptx.cuh" #include "../../../utils.cuh" #include "../core_nvfp4.cuh" +#include "../quantize_4over6_nvfp4.cuh" namespace transformer_engine { namespace dispatch { @@ -227,60 +228,43 @@ __device__ __forceinline__ void colwise_scaling( } const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), static_cast(__habs(thread_amax_2x.y))}; + #pragma unroll for (int w = 0; w < 2; ++w) { - __align__(8) uint32_t rOut[SCALE_DIM / 8]; - nvfp4_scale_t S_dec_b_fp8; - if constexpr (USE_4OVER6) { + __align__(8) uint32_t rOut[SCALE_DIM / 8]; + nvfp4_scale_t S_dec_b_fp8; nvfp4_scale_t S_dec_b_fp8_map4; nvfp4_scale_t S_dec_b_fp8_map6; core::compute_4over6_decoding_scaling_factors(block_amax[w], S_enc_colwise, S_dec_b_fp8_map4, S_dec_b_fp8_map6); - const scaling_coeff_type SFcoefficient_map4 = compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map4, S_enc_colwise); const scaling_coeff_type SFcoefficient_map6 = compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map6, S_enc_colwise); - float err_map4 = 0.0f; - float err_map6 = 0.0f; - __align__(8) uint32_t rOut_map4[SCALE_DIM / 8]; - __align__(8) uint32_t rOut_map6[SCALE_DIM / 8]; -#pragma unroll - for (int e = 0; e < SCALE_DIM / 8; ++e) { - const float x[8] = { - static_cast(rIn[w][8 * e + 0]), static_cast(rIn[w][8 * e + 1]), - static_cast(rIn[w][8 * e + 2]), static_cast(rIn[w][8 * e + 3]), - static_cast(rIn[w][8 * e + 4]), static_cast(rIn[w][8 * e + 5]), - static_cast(rIn[w][8 * e + 6]), static_cast(rIn[w][8 * e + 7]), - }; - rOut_map4[e] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax_colwise, - &err_map4); - rOut_map6[e] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax_colwise, - &err_map6); - } + core::quantize_4over6_contiguous_16x( + rIn[w], S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, SFcoefficient_map6, + global_amax_colwise, S_dec_b_fp8, rOut); - if (err_map4 < err_map6) { - S_dec_b_fp8 = S_dec_b_fp8_map4; -#pragma unroll - for (int e = 0; e < SCALE_DIM / 8; ++e) { - rOut[e] = rOut_map4[e]; - } - } else { - S_dec_b_fp8 = S_dec_b_fp8_map6; -#pragma unroll - for (int e = 0; e < SCALE_DIM / 8; ++e) { - rOut[e] = rOut_map6[e]; - } - } + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + + uint64_t &out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); } else { - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + const scaling_coeff_type SFcoefficient = compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); + // Scale elements + __align__(8) uint32_t rOut[SCALE_DIM / 8]; #pragma unroll for (int e = 0; e < SCALE_DIM / 8; ++e) { const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); @@ -295,14 +279,11 @@ __device__ __forceinline__ void colwise_scaling( SFcoefficient); } } - } - - // Store scaling factors to SMEM buffer (R2S) - sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; - uint64_t &out_pack_16x = *reinterpret_cast(rOut); - ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], - out_pack_16x); + uint64_t &out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } } } @@ -358,32 +339,28 @@ __device__ __forceinline__ void rowwise_scaling( } const float block_amax = get_amax_of_pair(thread_amax_2x); - nvfp4_scale_t S_dec_b_fp8; - float block_S_enc_rowwise; - float block_global_amax; - if constexpr (ROW_SCALED_NVFP4) { - const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; - if (row_idx < rows) { - block_global_amax = amax_rowwise_ptr[row_idx]; - block_S_enc_rowwise = - core::compute_global_encode_scaling_factor_FP4(block_global_amax); - } else { - block_global_amax = 1.0f; - block_S_enc_rowwise = 1.0f; - } - } else { - block_global_amax = *amax_rowwise_ptr; - block_S_enc_rowwise = S_enc_rowwise; - } - - __align__(8) uint32_t rOut[WAVES]; - if constexpr (USE_4OVER6) { + nvfp4_scale_t S_dec_b_fp8; nvfp4_scale_t S_dec_b_fp8_map4; nvfp4_scale_t S_dec_b_fp8_map6; + float block_S_enc_rowwise; + float block_global_amax; + if constexpr (ROW_SCALED_NVFP4) { + const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + if (row_idx < rows) { + block_global_amax = amax_rowwise_ptr[row_idx]; + block_S_enc_rowwise = + core::compute_global_encode_scaling_factor_FP4(block_global_amax); + } else { + block_global_amax = 1.0f; + block_S_enc_rowwise = 1.0f; + } + } else { + block_global_amax = *amax_rowwise_ptr; + block_S_enc_rowwise = S_enc_rowwise; + } core::compute_4over6_decoding_scaling_factors(block_amax, block_S_enc_rowwise, S_dec_b_fp8_map4, S_dec_b_fp8_map6); - const scaling_coeff_type SFcoefficient_map4 = compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map4, block_S_enc_rowwise); @@ -391,112 +368,76 @@ __device__ __forceinline__ void rowwise_scaling( compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map6, block_S_enc_rowwise); - float err_map4 = 0.0f; - float err_map6 = 0.0f; - __align__(8) uint32_t rOut_map4[WAVES]; - __align__(8) uint32_t rOut_map6[WAVES]; - + __align__(8) uint32_t rOut[WAVES]; if (bank_group == 0) { - const float x0[8] = { - static_cast(rIn[0][0].x), static_cast(rIn[0][0].y), - static_cast(rIn[0][1].x), static_cast(rIn[0][1].y), - static_cast(rIn[0][2].x), static_cast(rIn[0][2].y), - static_cast(rIn[0][3].x), static_cast(rIn[0][3].y), - }; - rOut_map4[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x0, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, - &err_map4); - rOut_map6[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x0, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, - &err_map6); - - const float x1[8] = { - static_cast(rIn[1][0].x), static_cast(rIn[1][0].y), - static_cast(rIn[1][1].x), static_cast(rIn[1][1].y), - static_cast(rIn[1][2].x), static_cast(rIn[1][2].y), - static_cast(rIn[1][3].x), static_cast(rIn[1][3].y), - }; - rOut_map4[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x1, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, - &err_map4); - rOut_map6[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x1, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, - &err_map6); + core::quantize_4over6_pair_array_16x( + rIn, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, SFcoefficient_map6, + block_global_amax, S_dec_b_fp8, rOut); } else { - const float x1[8] = { - static_cast(rIn[1][0].x), static_cast(rIn[1][0].y), - static_cast(rIn[1][1].x), static_cast(rIn[1][1].y), - static_cast(rIn[1][2].x), static_cast(rIn[1][2].y), - static_cast(rIn[1][3].x), static_cast(rIn[1][3].y), - }; - rOut_map4[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x1, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, - &err_map4); - rOut_map6[1] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x1, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, - &err_map6); - - const float x0[8] = { - static_cast(rIn[0][0].x), static_cast(rIn[0][0].y), - static_cast(rIn[0][1].x), static_cast(rIn[0][1].y), - static_cast(rIn[0][2].x), static_cast(rIn[0][2].y), - static_cast(rIn[0][3].x), static_cast(rIn[0][3].y), - }; - rOut_map4[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x0, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, block_global_amax, - &err_map4); - rOut_map6[0] = core::cvt_fp32_to_fp4_8x_with_mse_rn( - x0, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, block_global_amax, - &err_map6); + core::quantize_4over6_pair_array_16x( + rIn, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, SFcoefficient_map6, + block_global_amax, S_dec_b_fp8, rOut); + } + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; } - if (err_map4 < err_map6) { - S_dec_b_fp8 = S_dec_b_fp8_map4; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - rOut[w] = rOut_map4[w]; - } - } else { - S_dec_b_fp8 = S_dec_b_fp8_map6; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - rOut[w] = rOut_map6[w]; - } + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], rOut[w]); } } else { - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, block_S_enc_rowwise); - const scaling_coeff_type SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, block_S_enc_rowwise); + nvfp4_scale_t S_dec_b_fp8; + scaling_coeff_type SFcoefficient; + if constexpr (ROW_SCALED_NVFP4) { + const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + const float S_enc_rowwise_block = + row_idx < rows + ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) + : 1.0f; + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); + } else { + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + } + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } +// Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); + uint32_t out_x8; if constexpr (USE_STOCHASTIC_ROUNDING) { const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); - rOut[w] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( elts03, elts47, SFcoefficient, rbits03, rbits47); } else { - rOut[w] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, - SFcoefficient); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); } - } - } - // Store scaling factors to SMEM buffer (R2S) - if (SF_storing_thread) { - const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = stage_rowwise_scales_offset_X; - sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; - } - -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; - const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; - ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], rOut[w]); + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } } } } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index c3bfb0ab7c..26f7170fc2 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -15,6 +15,7 @@ #include #include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/cast/nvfp4/quantize_4over6_nvfp4.cuh" #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" @@ -538,40 +539,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo row_global_amax = global_amax[0]; } - float err_map4 = 0.0f; - float err_map6 = 0.0f; - uint32_t output_vec_map4[2]; - uint32_t output_vec_map6[2]; -#pragma unroll - for (int i = 0; i < kNVecOut / kNVecSMem; i += 4) { - const int out_idx = i / 4; - const float x[8] = { - static_cast(smem_vec[i + 0].data.elt[0]), - static_cast(smem_vec[i + 0].data.elt[1]), - static_cast(smem_vec[i + 1].data.elt[0]), - static_cast(smem_vec[i + 1].data.elt[1]), - static_cast(smem_vec[i + 2].data.elt[0]), - static_cast(smem_vec[i + 2].data.elt[1]), - static_cast(smem_vec[i + 3].data.elt[0]), - static_cast(smem_vec[i + 3].data.elt[1]), - }; - output_vec_map4[out_idx] = - transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< - kUseFastMath>(x, encode_scale_map4, scale_inv_map4, row_global_amax, &err_map4); - output_vec_map6[out_idx] = - transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< - kUseFastMath>(x, encode_scale_map6, scale_inv_map6, row_global_amax, &err_map6); - } - - if (err_map4 < err_map6) { - scale_inv = scale_inv_map4; - *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map4[0]; - *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map4[1]; - } else { - scale_inv = scale_inv_map6; - *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map6[0]; - *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map6[1]; - } + uint32_t output_vec_4over6[2]; + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_16x( + smem_vec, scale_inv_map4, scale_inv_map6, encode_scale_map4, encode_scale_map6, + row_global_amax, scale_inv, output_vec_4over6); + transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, + output_vec); } else { scale_inv = ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); @@ -712,40 +686,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const float encode_scale_map6 = ComputeEncodeScaleFP4(scale_inv_map6, global_decode_scale); - float err_map4 = 0.0f; - float err_map6 = 0.0f; - uint32_t output_vec_map4[2]; - uint32_t output_vec_map6[2]; -#pragma unroll - for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 4) { - const int out_idx = i / 4; - const float x[8] = { - static_cast(smem_vec[2 * (i + 0)].data.elt[smem_idx]), - static_cast(smem_vec[2 * (i + 0) + 1].data.elt[smem_idx]), - static_cast(smem_vec[2 * (i + 1)].data.elt[smem_idx]), - static_cast(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx]), - static_cast(smem_vec[2 * (i + 2)].data.elt[smem_idx]), - static_cast(smem_vec[2 * (i + 2) + 1].data.elt[smem_idx]), - static_cast(smem_vec[2 * (i + 3)].data.elt[smem_idx]), - static_cast(smem_vec[2 * (i + 3) + 1].data.elt[smem_idx]), - }; - output_vec_map4[out_idx] = - transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< - kUseFastMath>(x, encode_scale_map4, scale_inv_map4, global_amax[0], &err_map4); - output_vec_map6[out_idx] = - transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn< - kUseFastMath>(x, encode_scale_map6, scale_inv_map6, global_amax[0], &err_map6); - } - - if (err_map4 < err_map6) { - scale_inv = scale_inv_map4; - *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map4[0]; - *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map4[1]; - } else { - scale_inv = scale_inv_map6; - *reinterpret_cast(&output_vec.data.elt[0]) = output_vec_map6[0]; - *reinterpret_cast(&output_vec.data.elt[4]) = output_vec_map6[1]; - } + uint32_t output_vec_4over6[2]; + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_16x( + smem_vec, smem_idx, scale_inv_map4, scale_inv_map6, encode_scale_map4, + encode_scale_map6, global_amax[0], scale_inv, output_vec_4over6); + transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, + output_vec); } else { scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); From f4e4a4e8ed5b82daf7634a084a4a3020d153d3b3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 22:42:50 -0700 Subject: [PATCH 15/57] Further extract Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 149 ++++++++++++------ .../quantize_transpose_nvfp4_tuned_1D.cuh | 41 ++--- ...quantize_transpose_vector_blockwise_fp4.cu | 34 ++-- 3 files changed, 121 insertions(+), 103 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index eada49ed45..0a27544323 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -16,6 +16,8 @@ #include #include +#include + #include "core_nvfp4.cuh" namespace transformer_engine { @@ -34,6 +36,62 @@ __device__ __forceinline__ void compute_4over6_decoding_scaling_factors( S_dec_b_fp8_map6 = static_cast(sf_high_precision); } +template +struct QuantizationScales4Over6 { + nvfp4_scale_t S_dec_b_fp8_map4; + nvfp4_scale_t S_dec_b_fp8_map6; + scaling_coeff_type SFcoefficient_map4; + scaling_coeff_type SFcoefficient_map6; +}; + +template +__device__ __forceinline__ scaling_coeff_type +compute_4over6_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { + if constexpr (std::is_same_v) { + const float S_dec = 1.0f / S_enc; + const float scale_rcp = + fminf(1.0f / (static_cast(S_dec_block) * S_dec), detail::TypeExtrema::max); + return scale_rcp; + } else if constexpr (std::is_same_v) { + const float scale_rcp = + fminf(S_enc / static_cast(S_dec_block), detail::TypeExtrema::max); + return static_cast(scale_rcp); + } else { + NVTE_DEVICE_ERROR("Unsupported scaling-factor type. Only FP32 and BF16 are supported."); + return scaling_coeff_type{}; + } +} + +template +__device__ __forceinline__ QuantizationScales4Over6 +compute_4over6_nvfp4_quantization_scaling_factors(const float block_amax, const float S_enc) { + QuantizationScales4Over6 scaling_factors; + compute_4over6_decoding_scaling_factors(block_amax, S_enc, scaling_factors.S_dec_b_fp8_map4, + scaling_factors.S_dec_b_fp8_map6); + scaling_factors.SFcoefficient_map4 = compute_4over6_nvfp4_scaling_coefficient( + scaling_factors.S_dec_b_fp8_map4, S_enc); + scaling_factors.SFcoefficient_map6 = compute_4over6_nvfp4_scaling_coefficient( + scaling_factors.S_dec_b_fp8_map6, S_enc); + return scaling_factors; +} + +__device__ __forceinline__ QuantizationScales4Over6 +compute_4over6_fp4_encode_quantization_scaling_factors(const float block_amax, + const float global_encode_scale, + const float global_decode_scale) { + QuantizationScales4Over6 scaling_factors; + compute_4over6_decoding_scaling_factors(block_amax, global_encode_scale, + scaling_factors.S_dec_b_fp8_map4, + scaling_factors.S_dec_b_fp8_map6); + scaling_factors.SFcoefficient_map4 = + fminf(1.0f / (static_cast(scaling_factors.S_dec_b_fp8_map4) * global_decode_scale), + detail::TypeExtrema::max); + scaling_factors.SFcoefficient_map6 = + fminf(1.0f / (static_cast(scaling_factors.S_dec_b_fp8_map6) * global_decode_scale), + detail::TypeExtrema::max); + return scaling_factors; +} + template __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float (&x)[8], const float block_scale_inverse, @@ -166,12 +224,11 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float ( return out; } -template +template __device__ __forceinline__ void quantize_4over6_16x( const float (&first_half)[8], const float (&second_half)[8], - const nvfp4_scale_t S_dec_b_fp8_map4, const nvfp4_scale_t S_dec_b_fp8_map6, - const scaling_coeff_type SFcoefficient_map4, const scaling_coeff_type SFcoefficient_map6, - const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + const QuantizationScales4Over6 &scaling_factors, const float global_amax, + nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float err_map4 = 0.0f; float err_map6 = 0.0f; __align__(8) uint32_t rOut_map4[2]; @@ -179,38 +236,38 @@ __device__ __forceinline__ void quantize_4over6_16x( if constexpr (REVERSE_PACK_ORDER) { rOut_map4[1] = cvt_fp32_to_fp4_8x_with_mse_rn( - second_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, - &err_map4); + second_half, static_cast(scaling_factors.SFcoefficient_map4), + scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); rOut_map6[1] = cvt_fp32_to_fp4_8x_with_mse_rn( - second_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, - &err_map6); + second_half, static_cast(scaling_factors.SFcoefficient_map6), + scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); rOut_map4[0] = cvt_fp32_to_fp4_8x_with_mse_rn( - first_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, - &err_map4); + first_half, static_cast(scaling_factors.SFcoefficient_map4), + scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); rOut_map6[0] = cvt_fp32_to_fp4_8x_with_mse_rn( - first_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, - &err_map6); + first_half, static_cast(scaling_factors.SFcoefficient_map6), + scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } else { rOut_map4[0] = cvt_fp32_to_fp4_8x_with_mse_rn( - first_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, - &err_map4); + first_half, static_cast(scaling_factors.SFcoefficient_map4), + scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); rOut_map6[0] = cvt_fp32_to_fp4_8x_with_mse_rn( - first_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, - &err_map6); + first_half, static_cast(scaling_factors.SFcoefficient_map6), + scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); rOut_map4[1] = cvt_fp32_to_fp4_8x_with_mse_rn( - second_half, static_cast(SFcoefficient_map4), S_dec_b_fp8_map4, global_amax, - &err_map4); + second_half, static_cast(scaling_factors.SFcoefficient_map4), + scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); rOut_map6[1] = cvt_fp32_to_fp4_8x_with_mse_rn( - second_half, static_cast(SFcoefficient_map6), S_dec_b_fp8_map6, global_amax, - &err_map6); + second_half, static_cast(scaling_factors.SFcoefficient_map6), + scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } if (err_map4 < err_map6) { - S_dec_b_fp8 = S_dec_b_fp8_map4; + S_dec_b_fp8 = scaling_factors.S_dec_b_fp8_map4; rOut[0] = rOut_map4[0]; rOut[1] = rOut_map4[1]; } else { - S_dec_b_fp8 = S_dec_b_fp8_map6; + S_dec_b_fp8 = scaling_factors.S_dec_b_fp8_map6; rOut[0] = rOut_map6[0]; rOut[1] = rOut_map6[1]; } @@ -223,11 +280,10 @@ __device__ __forceinline__ void store_4over6_packed_16x(const uint32_t (&packed) *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; } -template __device__ __forceinline__ void quantize_4over6_contiguous_16x( - const input_type *x, const nvfp4_scale_t S_dec_b_fp8_map4, const nvfp4_scale_t S_dec_b_fp8_map6, - const scaling_coeff_type SFcoefficient_map4, const scaling_coeff_type SFcoefficient_map6, + const input_type *x, const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; @@ -237,18 +293,15 @@ __device__ __forceinline__ void quantize_4over6_contiguous_16x( second_half[i] = static_cast(x[i + 8]); } - quantize_4over6_16x( - first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, - SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x(first_half, second_half, scaling_factors, + global_amax, S_dec_b_fp8, rOut); } -template __device__ __forceinline__ void quantize_4over6_pair_array_16x( - const pair_type (&x)[2][4], const nvfp4_scale_t S_dec_b_fp8_map4, - const nvfp4_scale_t S_dec_b_fp8_map6, const scaling_coeff_type SFcoefficient_map4, - const scaling_coeff_type SFcoefficient_map6, const float global_amax, - nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + const pair_type (&x)[2][4], const QuantizationScales4Over6 &scaling_factors, + const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; #pragma unroll @@ -259,18 +312,15 @@ __device__ __forceinline__ void quantize_4over6_pair_array_16x( second_half[2 * i + 1] = static_cast(x[1][i].y); } - quantize_4over6_16x( - first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, - SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x(first_half, second_half, scaling_factors, + global_amax, S_dec_b_fp8, rOut); } -template __device__ __forceinline__ void quantize_4over6_vec2_array_16x( - const vec_type (&x)[8], const nvfp4_scale_t S_dec_b_fp8_map4, - const nvfp4_scale_t S_dec_b_fp8_map6, const scaling_coeff_type SFcoefficient_map4, - const scaling_coeff_type SFcoefficient_map6, const float global_amax, - nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, + const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; #pragma unroll @@ -281,17 +331,15 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_16x( second_half[2 * i + 1] = static_cast(x[i + 4].data.elt[1]); } - quantize_4over6_16x( - first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, - SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x(first_half, second_half, scaling_factors, + global_amax, S_dec_b_fp8, rOut); } -template __device__ __forceinline__ void quantize_4over6_vec_index_16x( - const vec_type (&x)[16], const int idx, const nvfp4_scale_t S_dec_b_fp8_map4, - const nvfp4_scale_t S_dec_b_fp8_map6, const scaling_coeff_type SFcoefficient_map4, - const scaling_coeff_type SFcoefficient_map6, const float global_amax, + const vec_type (&x)[16], const int idx, + const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; @@ -301,9 +349,8 @@ __device__ __forceinline__ void quantize_4over6_vec_index_16x( second_half[i] = static_cast(x[i + 8].data.elt[idx]); } - quantize_4over6_16x( - first_half, second_half, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, - SFcoefficient_map6, global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x(first_half, second_half, scaling_factors, + global_amax, S_dec_b_fp8, rOut); } #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 6bfda2f24e..17c27f827a 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -234,18 +234,12 @@ __device__ __forceinline__ void colwise_scaling( if constexpr (USE_4OVER6) { __align__(8) uint32_t rOut[SCALE_DIM / 8]; nvfp4_scale_t S_dec_b_fp8; - nvfp4_scale_t S_dec_b_fp8_map4; - nvfp4_scale_t S_dec_b_fp8_map6; - core::compute_4over6_decoding_scaling_factors(block_amax[w], S_enc_colwise, S_dec_b_fp8_map4, - S_dec_b_fp8_map6); - const scaling_coeff_type SFcoefficient_map4 = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map4, S_enc_colwise); - const scaling_coeff_type SFcoefficient_map6 = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map6, S_enc_colwise); - - core::quantize_4over6_contiguous_16x( - rIn[w], S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, SFcoefficient_map6, - global_amax_colwise, S_dec_b_fp8, rOut); + const auto scaling_factors = + core::compute_4over6_nvfp4_quantization_scaling_factors( + block_amax[w], S_enc_colwise); + + core::quantize_4over6_contiguous_16x(rIn[w], scaling_factors, + global_amax_colwise, S_dec_b_fp8, rOut); // Store scaling factors to SMEM buffer (R2S) sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; @@ -341,8 +335,6 @@ __device__ __forceinline__ void rowwise_scaling( if constexpr (USE_4OVER6) { nvfp4_scale_t S_dec_b_fp8; - nvfp4_scale_t S_dec_b_fp8_map4; - nvfp4_scale_t S_dec_b_fp8_map6; float block_S_enc_rowwise; float block_global_amax; if constexpr (ROW_SCALED_NVFP4) { @@ -359,24 +351,17 @@ __device__ __forceinline__ void rowwise_scaling( block_global_amax = *amax_rowwise_ptr; block_S_enc_rowwise = S_enc_rowwise; } - core::compute_4over6_decoding_scaling_factors(block_amax, block_S_enc_rowwise, - S_dec_b_fp8_map4, S_dec_b_fp8_map6); - const scaling_coeff_type SFcoefficient_map4 = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map4, - block_S_enc_rowwise); - const scaling_coeff_type SFcoefficient_map6 = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8_map6, - block_S_enc_rowwise); + const auto scaling_factors = + core::compute_4over6_nvfp4_quantization_scaling_factors( + block_amax, block_S_enc_rowwise); __align__(8) uint32_t rOut[WAVES]; if (bank_group == 0) { - core::quantize_4over6_pair_array_16x( - rIn, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, SFcoefficient_map6, - block_global_amax, S_dec_b_fp8, rOut); + core::quantize_4over6_pair_array_16x(rIn, scaling_factors, block_global_amax, + S_dec_b_fp8, rOut); } else { - core::quantize_4over6_pair_array_16x( - rIn, S_dec_b_fp8_map4, S_dec_b_fp8_map6, SFcoefficient_map4, SFcoefficient_map6, - block_global_amax, S_dec_b_fp8, rOut); + core::quantize_4over6_pair_array_16x( + rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } // Store scaling factors to SMEM buffer (R2S) diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 26f7170fc2..44abd108e9 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -520,14 +520,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float encode_scale; OVec output_vec; if constexpr (kUse4Over6) { - ScaleType scale_inv_map4; - ScaleType scale_inv_map6; - transformer_engine::dispatch::nvfp4::core::compute_4over6_decoding_scaling_factors( - amax, row_global_encode_scale, scale_inv_map4, scale_inv_map6); - const float encode_scale_map4 = - ComputeEncodeScaleFP4(scale_inv_map4, row_global_decode_scale); - const float encode_scale_map6 = - ComputeEncodeScaleFP4(scale_inv_map6, row_global_decode_scale); + const auto scaling_factors = transformer_engine::dispatch::nvfp4::core:: + compute_4over6_fp4_encode_quantization_scaling_factors(amax, row_global_encode_scale, + row_global_decode_scale); float row_global_amax; if constexpr (kRowScaledNVFP4) { if (row_idx < num_rows) { @@ -540,10 +535,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } uint32_t output_vec_4over6[2]; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_16x( - smem_vec, scale_inv_map4, scale_inv_map6, encode_scale_map4, encode_scale_map6, - row_global_amax, scale_inv, output_vec_4over6); + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_16x( + smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, output_vec); } else { @@ -677,20 +670,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float encode_scale; OVec output_vec; if constexpr (kUse4Over6) { - ScaleType scale_inv_map4; - ScaleType scale_inv_map6; - transformer_engine::dispatch::nvfp4::core::compute_4over6_decoding_scaling_factors( - amax, global_encode_scale, scale_inv_map4, scale_inv_map6); - const float encode_scale_map4 = - ComputeEncodeScaleFP4(scale_inv_map4, global_decode_scale); - const float encode_scale_map6 = - ComputeEncodeScaleFP4(scale_inv_map6, global_decode_scale); + const auto scaling_factors = transformer_engine::dispatch::nvfp4::core:: + compute_4over6_fp4_encode_quantization_scaling_factors(amax, global_encode_scale, + global_decode_scale); uint32_t output_vec_4over6[2]; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_16x( - smem_vec, smem_idx, scale_inv_map4, scale_inv_map6, encode_scale_map4, - encode_scale_map6, global_amax[0], scale_inv, output_vec_4over6); + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_16x( + smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, output_vec); } else { From b3f59eebc4099d277b5028379f0871770e5cd7ab Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 23:04:01 -0700 Subject: [PATCH 16/57] Clean up Signed-off-by: Ziang Li --- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 10 ++-------- .../quantize_transpose_nvfp4_tuned_1D.cuh | 9 +++++++-- .../common/include/transformer_engine/recipe.h | 3 --- transformer_engine/common/recipe/nvfp4.cu | 14 ++------------ 4 files changed, 11 insertions(+), 25 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 8f12be20bf..630eb540ca 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -63,14 +63,8 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax; - if constexpr (ROW_SCALED_NVFP4) { - amax = tensor_amax[y]; - } else { - amax = tensor_amax[0]; - } - constexpr float fp8_max = USE_4OVER6 ? 256.0f : 448.0f; - const float factor_inv = 1.0f / (6.0f * fp8_max); + float amax = ROW_SCALED_NVFP4 ? tensor_amax[y] : tensor_amax[0]; + constexpr float factor_inv = 1.0f / (6.0f * (USE_4OVER6 ? 256.0f : 448.0f)); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 17c27f827a..e210a1cb8d 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -228,7 +228,6 @@ __device__ __forceinline__ void colwise_scaling( } const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), static_cast(__habs(thread_amax_2x.y))}; - #pragma unroll for (int w = 0; w < 2; ++w) { if constexpr (USE_4OVER6) { @@ -273,7 +272,6 @@ __device__ __forceinline__ void colwise_scaling( SFcoefficient); } } - uint64_t &out_pack_16x = *reinterpret_cast(rOut); ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], out_pack_16x); @@ -502,6 +500,8 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D (amax_colwise_ptr == nullptr) ? S_enc_rowwise : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + // Original NVFP4 uses a scalar per-tensor amax for both rowwise and columnwise output. + // If no dedicated columnwise amax buffer is allocated, the rowwise amax is that same scalar. const float global_amax_colwise = (amax_colwise_ptr == nullptr) ? ((amax_rowwise_ptr == nullptr) ? 1.0f : *amax_rowwise_ptr) : *amax_colwise_ptr; @@ -816,6 +816,11 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); const float *const amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + if (use_4over6 && return_transpose && amax_colwise_ptr == nullptr) { + NVTE_CHECK(amax_rowwise_ptr != nullptr && output->amax.numel() == 1, + "NVFP4 4over6 quantization with columnwise output requires columnwise amax " + "or scalar per-tensor rowwise amax."); + } const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; const size_t *rng_state = nullptr; diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index fe729d3b30..cad27a2992 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -304,9 +304,6 @@ void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_r * \param[in] alpha_in Input scaling factor. * \param[out] alpha_out Output scaling factor. * \param[in] stream CUDA stream used for the operation. - * - * Uses each NVFP4 tensor's 4over6 metadata to choose the matching FP8 max - * when folding global amax values into the GEMM alpha. */ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, const NVTETensor inpB, const bool use_rowwise_amax_B, diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 9c1397be4f..23ca156092 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -924,18 +924,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr; void *alpha_ptr = tOut->data.dptr; - float fp8_max_A; - if (tA->nvfp4_4over6) { - fp8_max_A = 256.0f; - } else { - fp8_max_A = 448.0f; - } - float fp8_max_B; - if (tB->nvfp4_4over6) { - fp8_max_B = 256.0f; - } else { - fp8_max_B = 448.0f; - } + const float fp8_max_A = tA->nvfp4_4over6 ? 256.0f : 448.0f; + const float fp8_max_B = tB->nvfp4_4over6 ? 256.0f : 448.0f; // check for not null pointers NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); From 31decf9ace357a51f8829839462d3052bd00383b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 9 May 2026 23:28:07 -0700 Subject: [PATCH 17/57] Add recipe_id Signed-off-by: Ziang Li --- tests/pytorch/test_cpu_offloading.py | 18 +++++++++--------- tests/pytorch/test_numerics.py | 20 ++++++++++---------- tests/pytorch/test_sanity.py | 16 ++++++++-------- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index a897e492c5..a5b53bbea9 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -19,7 +19,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from utils import ModelConfig, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, skip_unsupported_backward_override import transformer_engine_torch as tex # Check supported quantization schemes @@ -270,7 +270,7 @@ def memory_leak_check(): class TestsOffloadableLayerState: @pytest.mark.parametrize("random_num_tensors", [True, False]) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_general(self, random_num_tensors, recipe): """ Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, @@ -366,7 +366,7 @@ def test_offload_base_tensor(self): class TestsDefaultOffloadSynchronizer: @pytest.mark.parametrize("random_num_tensors", [True, False]) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_general(self, random_num_tensors, recipe): """ Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, @@ -412,7 +412,7 @@ def test_general(self, random_num_tensors, recipe): offload_synchronizer.finish_part_of_bwd() torch.cuda.synchronize() - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_memory(self, recipe): torch.cuda.synchronize() Utils.memory_leak_check() @@ -467,7 +467,7 @@ def test_memory(self, recipe): ) assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_multiple_tensor_offload(self, recipe): Utils.memory_leak_check() init_cpu_memory = Utils.get_cpu_memory_mb() @@ -498,7 +498,7 @@ def test_multiple_tensor_offload(self, recipe): class TestTELayers: @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_sanity(self, layer_type, recipe, backward_override): Utils.memory_leak_check() @@ -545,7 +545,7 @@ def test_sanity(self, layer_type, recipe, backward_override): del out, inp, layers @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_memory(self, layer_type, recipe, backward_override): Utils.memory_leak_check() @@ -638,7 +638,7 @@ def test_memory(self, layer_type, recipe, backward_override): out.sum().backward() @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_manual_synchronization(self, recipe, layer_type, backward_override): Utils.memory_leak_check() @@ -707,7 +707,7 @@ def test_manual_synchronization(self, recipe, layer_type, backward_override): out_1.sum().backward() out_2.sum().backward() - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("use_cuda_graphs", [True, False]) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 430f5c12f3..eeded6ff45 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -54,7 +54,7 @@ from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex -from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override # Only run FP8 tests on supported devices. @@ -656,7 +656,7 @@ def _test_e2e_selective_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: @@ -772,7 +772,7 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) def test_gpt_full_activation_recompute( @@ -1361,7 +1361,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) def test_linear_accuracy_save_original_input(dtype, model, recipe): bs = 1 fuse_wgrad_accumulation = True @@ -1932,7 +1932,7 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -2078,7 +2078,7 @@ def test_grouped_linear_accuracy_cutlass( @pytest.mark.parametrize("num_gemms", [3]) @pytest.mark.parametrize("bs", [1]) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", [False]) @pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) @pytest.mark.parametrize("bias", [False]) @@ -2183,7 +2183,7 @@ def test_grouped_linear_accuracy_save_original_input( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( @@ -2297,7 +2297,7 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( dtype, @@ -2375,7 +2375,7 @@ def test_padding_grouped_linear_accuracy( @pytest.mark.parametrize("bs", [1]) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", [False]) def test_padding_grouped_linear_accuracy_save_original_input( dtype, @@ -2609,7 +2609,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 80d057c6d5..277ed7ca4d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -654,7 +654,7 @@ def test_sanity_grouped_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @@ -704,7 +704,7 @@ def test_sanity_layernorm_mlp( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -777,7 +777,7 @@ def test_sanity_gpt_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) @@ -833,7 +833,7 @@ def test_sanity_bert_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) @@ -889,7 +889,7 @@ def test_sanity_T5_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @@ -922,7 +922,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) def test_sanity_drop_path(dtype, fp8_recipe, model): config = model_configs[model] @@ -957,7 +957,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): @@ -993,7 +993,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad): From 2fa6b8cbd0a5eff27f894f0f43615e87214e34d3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 10 May 2026 00:26:19 -0700 Subject: [PATCH 18/57] Fix failing unit tests Signed-off-by: Ziang Li --- tests/pytorch/test_recipe.py | 3 +++ tests/pytorch/test_torch_compile.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 095fd60b8c..5b18a53f70 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -517,6 +517,9 @@ def test_quantizer_update(self, module_class): @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) def test_nvfp4_row_scaled_quantizer_roles(use_4over6): recipe = NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, enable_4over6=use_4over6, row_scaled_activation=True, ) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 74be8cb358..c3eb024b94 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -126,8 +126,19 @@ def __fx_repr__(self): def _make_qfactory(tag: str): """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" + quantizers = { + role: ToyQuantizer(tag=f"{tag}:{role}") + for role in ( + "linear_input", + "linear_weight", + "linear_output", + "linear_grad_output", + "linear_grad_input", + ) + } + def qfactory(role: str): - return ToyQuantizer(tag=f"{tag}:{role}") + return quantizers[role] return qfactory From 7df2db0d1fefbb13697dc254f81ef5037a564219 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 10 May 2026 01:28:30 -0700 Subject: [PATCH 19/57] Clean up test Signed-off-by: Ziang Li --- tests/pytorch/test_cpu_offloading.py | 25 +------------------------ tests/pytorch/test_fusible_ops.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index a5b53bbea9..8020a2ba10 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -249,10 +249,6 @@ def get_saved_tensor_gpu_size_mb(tensor): return sum(Utils.get_saved_tensor_gpu_size_mb(t) for t in tensor) return Utils.get_tensor_size_mb(tensor) - @staticmethod - def keeps_small_nvfp4_row_amax_on_gpu(recipe: Optional[recipe.Recipe]): - return recipe is not None and recipe.nvfp4() and recipe.row_scaled_activation - @staticmethod def memory_leak_check(): # Should be called before each test. @@ -310,25 +306,6 @@ def test_general(self, random_num_tensors, recipe): offload_layer_state.release_all_memory() torch.cuda.synchronize() - @pytest.mark.skipif(not nvfp4_available, reason="NVFP4 requires Blackwell") - def test_nvfp4_row_scaled_amax_stays_on_gpu(self): - Utils.memory_leak_check() - stream = torch.cuda.Stream() - offload_layer_state = OffloadableLayerState( - offload_stream=stream, - ) - tensor = Utils.create_tensor(nvfp4_row_scaled()) - tensor_id = offload_layer_state.push_tensor(tensor) - assert isinstance(tensor_id, tuple) - push_results, _ = tensor_id - assert isinstance(push_results[0], int) - assert isinstance(push_results[4], torch.Tensor) - assert push_results[4].device.type == "cuda" - assert push_results[4].numel() < 256 * 1024 - del tensor, tensor_id - offload_layer_state.release_all_memory() - torch.cuda.synchronize() - def test_offload_base_tensor(self): Utils.memory_leak_check() stream = torch.cuda.Stream() @@ -618,7 +595,7 @@ def test_memory(self, layer_type, recipe, backward_override): out = out + 1 out = sync_function(out) del inp - if Utils.keeps_small_nvfp4_row_amax_on_gpu(recipe): + if recipe is not None and recipe.nvfp4() and recipe.row_scaled_activation: assert Utils.get_cuda_memory_mb() <= cuda_memory_no_offload elif backward_override is None: assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index d612fea552..9d060611d2 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1514,7 +1514,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif "nvfp4" in quantization: + elif quantization is not None and "nvfp4" in quantization: tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -3614,7 +3614,12 @@ def test_grouped_mlp( pytest.skip("Quantized group GEMM is only supported with BF16/FP16") if with_quantization and "4over6" in quantization: pytest.skip("NVFP4 4over6 grouped quantization is not supported") - if "nvfp4" in quantization and activation == "scaled_clamped_qgeglu" and bias: + if ( + with_quantization + and "nvfp4" in quantization + and activation == "scaled_clamped_qgeglu" + and bias + ): # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") @@ -3836,7 +3841,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if "nvfp4" in quantization: + if with_quantization and "nvfp4" in quantization: tols = {"rtol": 0.25, "atol": 0.5} # Check values From ce85be25988e77079bce6ef96276fd35c214f26c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 10 May 2026 02:09:53 -0700 Subject: [PATCH 20/57] Clean up Signed-off-by: Ziang Li --- docs/envvars.rst | 2 +- tests/cpp/operator/test_dequantize_nvfp4.cu | 2 +- tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py | 13 ------------- .../quantize_transpose_vector_blockwise_fp4.cu | 14 +++++--------- 4 files changed, 7 insertions(+), 24 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 015593cb67..e02b798c06 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -291,7 +291,7 @@ Kernel Configuration :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled, either with the corresponding recipe fields or with :envvar:`NVTE_NVFP4_DISABLE_RHT`, :envvar:`NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING`, and :envvar:`NVTE_NVFP4_DISABLE_2D_QUANTIZATION`. + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index aebd05da5d..ea2ef14916 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -48,7 +48,7 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, size_t cols, size_t scale_stride, bool use_4over6) { - constexpr float factor_inv = 1.0f / (6.0f * (use_4over6 ? 256.0f : 448.0f)); + const float factor_inv = 1.0f / (6.0f * (use_4over6 ? 256.0f : 448.0f)); constexpr size_t BLOCK_SIZE = 16; const size_t Mread = cols / BLOCK_SIZE; const size_t bytes_per_block = BLOCK_SIZE / 2; diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 48d4717ecc..d7f9c8994e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -2,8 +2,6 @@ # # See LICENSE for license information. -import os - import pytest import torch import transformer_engine.pytorch as te @@ -18,17 +16,6 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) -@pytest.fixture(autouse=True) -def disable_fast_math_for_exact_reference_tests(): - original = os.environ.get("NVTE_USE_FAST_MATH") - os.environ["NVTE_USE_FAST_MATH"] = "0" - yield - if original is None: - os.environ.pop("NVTE_USE_FAST_MATH", None) - else: - os.environ["NVTE_USE_FAST_MATH"] = original - - def maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4: bool, return_transpose: bool, diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 44abd108e9..b69f790531 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -34,6 +34,7 @@ using std::uint32_t; using std::uint8_t; using transformer_engine::detail::TypeExtrema; +using transformer_engine::dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; // clang-format off /* @@ -189,12 +190,6 @@ __device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scal return static_cast(input) * encode_scale; } -template -__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { - return transformer_engine::dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4< - kUse4Over6>(global_amax); -} - __device__ __forceinline__ uint32_t get_rbits( transformer_engine::curanddx::detail::philox4x32_native_state& rng, // NVTE_BUILD_NUM_PHILOX_ROUNDS rounds of philox4x32 @@ -414,7 +409,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; const float global_encode_scale = - kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); + kIsE8Scaling ? 1.0f : compute_global_encode_scaling_factor_FP4(global_amax[0]); constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; const float global_decode_scale = 1.0 / global_encode_scale; @@ -509,8 +504,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float row_global_encode_scale = global_encode_scale; if constexpr (kRowScaledNVFP4) { row_global_encode_scale = - row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) - : 1.0f; + row_idx < num_rows + ? compute_global_encode_scaling_factor_FP4(global_amax[row_idx]) + : 1.0f; } const float row_global_encode_scale_multiplier = kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; From 1b68038149225e3cba653b5fa7ef0dbdb57b8a4d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 10 May 2026 02:21:50 -0700 Subject: [PATCH 21/57] Refactor ref Signed-off-by: Ziang Li --- .../custom_recipes/quantization_ref_nvfp4.py | 213 ++++++++++-------- 1 file changed, 115 insertions(+), 98 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 3f76037a6a..1929dbe884 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -456,6 +456,91 @@ def _recover_swizzled_scales( result = torch.reshape(tmp, (rounded_m, rounded_n)) return result[:m, :scale_n] + @staticmethod + def _quantize_blockwise_4over6_reference( + x: torch.Tensor, + vec_max: torch.Tensor, + global_amax: torch.Tensor, + global_encode_scale: torch.Tensor, + global_decode_scale: torch.Tensor, + row_scaled_nvfp4: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize NVFP4 with 4over6 candidate selection.""" + m, num_blocks, tile_len_x = x.shape + n = num_blocks * tile_len_x + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + GLOBAL_SCALE_E4M3_MAX = torch.tensor(256.0, device=x.device, dtype=torch.float32) + + decode_scale_map6 = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale + decode_scale_map4 = decode_scale_map6 * 1.5 + decode_scale_map4 = torch.clamp( + decode_scale_map4, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX + ).to(torch.float8_e4m3fn) + decode_scale_map6 = torch.clamp( + decode_scale_map6, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX + ).to(torch.float8_e4m3fn) + + fp32_max = torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale_map4.device, + dtype=torch.float32, + ) + encode_scale_map4 = torch.min( + torch.div(1.0, decode_scale_map4.to(torch.float32) * global_decode_scale), + fp32_max, + ) + encode_scale_map6 = torch.min( + torch.div(1.0, decode_scale_map6.to(torch.float32) * global_decode_scale), + fp32_max, + ) + + clipped_x_map4 = torch.clamp( + x.to(torch.float32) * encode_scale_map4, + -FLOAT4_E2M1_MAX, + FLOAT4_E2M1_MAX, + ).reshape(m, n) + clipped_x_map6 = torch.clamp( + x.to(torch.float32) * encode_scale_map6, + -FLOAT4_E2M1_MAX, + FLOAT4_E2M1_MAX, + ).reshape(m, n) + qx_map4 = cast_to_fp4x2(clipped_x_map4) + qx_map6 = cast_to_fp4x2(clipped_x_map6) + + fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view(m, num_blocks, tile_len_x) + fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view(m, num_blocks, tile_len_x) + denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX + sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) + sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) + if row_scaled_nvfp4: + mse_global_amax = global_amax.squeeze(-1) + else: + mse_global_amax = global_amax + x_float = x.to(torch.float32) + err_map4 = torch.zeros_like(vec_max) + err_map6 = torch.zeros_like(vec_max) + for idx in range(tile_len_x): + val_map4 = fp4_map4[:, :, idx] * sf_map4 + val_map4 = val_map4 * mse_global_amax + val_map4 = val_map4 / denom + diff_map4 = val_map4 - x_float[:, :, idx] + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + + val_map6 = fp4_map6[:, :, idx] * sf_map6 + val_map6 = val_map6 * mse_global_amax + val_map6 = val_map6 / denom + diff_map6 = val_map6 - x_float[:, :, idx] + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + pick_map4 = err_map4 < err_map6 + qx = torch.where( + pick_map4.expand(-1, -1, tile_len_x // 2), + qx_map4.view(m, num_blocks, tile_len_x // 2), + qx_map6.view(m, num_blocks, tile_len_x // 2), + ).reshape(m, n // 2) + decode_scale = torch.where(pick_map4, decode_scale_map4, decode_scale_map6).squeeze(-1) + return qx, decode_scale + @classmethod def _quantize_blockwise_reference( cls, @@ -539,107 +624,39 @@ def _quantize_blockwise_reference( # FourOverSix compares map-to-4 and map-to-6 candidates using # the original input-domain MSE, while keeping TE-style FP4 # quantization for each candidate. - decode_scale_map6 = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale - decode_scale_map4 = decode_scale_map6 * 1.5 - decode_scale_map4 = torch.clamp( - decode_scale_map4, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX - ).to(torch.float8_e4m3fn) - decode_scale_map6 = torch.clamp( - decode_scale_map6, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX - ).to(torch.float8_e4m3fn) - - fp32_max = torch.tensor( - torch.finfo(torch.float32).max, - device=decode_scale_map4.device, - dtype=torch.float32, - ) - encode_scale_map4 = torch.min( - torch.div(1.0, decode_scale_map4.to(torch.float32) * global_decode_scale), - fp32_max, - ) - encode_scale_map6 = torch.min( - torch.div(1.0, decode_scale_map6.to(torch.float32) * global_decode_scale), - fp32_max, + return cls._quantize_blockwise_4over6_reference( + x, + vec_max, + global_amax, + global_encode_scale, + global_decode_scale, + row_scaled_nvfp4, ) - clipped_x_map4 = torch.clamp( - x.to(torch.float32) * encode_scale_map4, - -FLOAT4_E2M1_MAX, - FLOAT4_E2M1_MAX, - ).reshape(m, n) - clipped_x_map6 = torch.clamp( - x.to(torch.float32) * encode_scale_map6, - -FLOAT4_E2M1_MAX, - FLOAT4_E2M1_MAX, - ).reshape(m, n) - qx_map4 = cast_to_fp4x2(clipped_x_map4) - qx_map6 = cast_to_fp4x2(clipped_x_map6) - - fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view( - m, n // tile_len_x, tile_len_x - ) - fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view( - m, n // tile_len_x, tile_len_x - ) - denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX - sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) - sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) - if row_scaled_nvfp4: - mse_global_amax = global_amax.squeeze(-1) - else: - mse_global_amax = global_amax - x_float = x.to(torch.float32) - err_map4 = torch.zeros_like(vec_max) - err_map6 = torch.zeros_like(vec_max) - for idx in range(tile_len_x): - val_map4 = fp4_map4[:, :, idx] * sf_map4 - val_map4 = val_map4 * mse_global_amax - val_map4 = val_map4 / denom - diff_map4 = val_map4 - x_float[:, :, idx] - err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) - - val_map6 = fp4_map6[:, :, idx] * sf_map6 - val_map6 = val_map6 * mse_global_amax - val_map6 = val_map6 / denom - diff_map6 = val_map6 - x_float[:, :, idx] - err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) - pick_map4 = err_map4 < err_map6 - qx = torch.where( - pick_map4.expand(-1, -1, tile_len_x // 2), - qx_map4.view(m, n // tile_len_x, tile_len_x // 2), - qx_map6.view(m, n // tile_len_x, tile_len_x // 2), - ).reshape(m, n // 2) - decode_scale = torch.where(pick_map4, decode_scale_map4, decode_scale_map6).squeeze( - -1 - ) - return qx, decode_scale - else: - global_encode_scale_multiplier = global_encode_scale * torch.reciprocal( - FLOAT4_E2M1_MAX - ) + global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - # Match the kernel's default path: fold the FP4 reciprocal into the - # global scale multiplier, but keep the final reciprocal exact. - decode_scale = vec_max * global_encode_scale_multiplier - decode_scale = torch.min( - decode_scale, - torch.tensor( - torch.finfo(torch.float32).max, - device=decode_scale.device, - dtype=torch.float32, - ), - ) - decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - decode_scale = decode_scale.to(torch.float8_e4m3fn) - - encode_scale = torch.min( - torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), - torch.tensor( - torch.finfo(torch.float32).max, - device=decode_scale.device, - dtype=torch.float32, - ), - ) + # Match the kernel's default path: fold the FP4 reciprocal into the + # global scale multiplier, but keep the final reciprocal exact. + decode_scale = vec_max * global_encode_scale_multiplier + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) scaled_x = x.to(torch.float32) * encode_scale From bb722a330339640ffacb1b3862749bbbd99d7691 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 10 May 2026 02:59:25 -0700 Subject: [PATCH 22/57] Update comments and docs Signed-off-by: Ziang Li --- docs/envvars.rst | 2 +- transformer_engine/common/cast/nvfp4/core_nvfp4.cuh | 1 + .../common/cast/nvfp4/quantize_4over6_nvfp4.cuh | 6 ++++++ transformer_engine/common/common.h | 4 +++- .../common/include/transformer_engine/transformer_engine.h | 6 +++++- transformer_engine/common/recipe/__init__.py | 1 + transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions/cast.cpp | 1 + .../pytorch/custom_recipes/quantization_ref_nvfp4.py | 7 ++++++- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 2 +- .../pytorch/tensor/storage/grouped_tensor_storage.py | 2 +- .../pytorch/tensor/storage/nvfp4_tensor_storage.py | 2 +- 12 files changed, 28 insertions(+), 7 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index e02b798c06..3285c7ec5b 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -291,7 +291,7 @@ Kernel Configuration :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled. + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of 448 bound, and currently requires RHT, stochastic rounding, and 2D quantization to be disabled. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index 92a6b7963e..d04417d47d 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -76,6 +76,7 @@ namespace core { using namespace ptx; // Compute the global encode scale factor for a given global amax +// 4over6 uses 256 instead of 448 to leave room for the map-to-4 scale expansion template __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { using namespace detail; diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 0a27544323..6548f3a5bc 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -6,6 +6,12 @@ /*! \file quantize_4over6_nvfp4.cuh * \brief Helpers used by NVFP4 4over6 quantization. + * + * 4over6 evaluates two TE-style NVFP4 encodings for each 1x16 block. The + * map-to-6 candidate uses the normal block scale. The map-to-4 candidate uses + * a 1.5x expanded block scale, which maps the FP4 value 4 to the same dynamic + * range as FP4 value 6. The selected candidate is the one with lower MSE after + * dequantizing back to the original input domain; ties select map-to-6. */ #ifndef TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index b9d1c3f70e..be59dd5068 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -180,7 +180,9 @@ struct Tensor { bool row_scaled_nvfp4 = false; /*! \brief Whether NVFP4 uses 4over6 block scale selection. * - * Only meaningful for NVFP4 tensors. + * Only meaningful for NVFP4 tensors. 4over6 tensors use 256 as their + * global E4M3 scale bound and store the lower-MSE map-to-4/map-to-6 + * candidate for each 1x16 block. */ bool nvfp4_4over6 = false; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 35bf020e8b..8c8155f99b 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -382,7 +382,11 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, - /*! Whether to use NVFP4 4over6 block scale selection */ + /*! Whether to use NVFP4 4over6 block scale selection. + * + * 4over6 evaluates map-to-4 and map-to-6 candidates for each 1x16 block, + * stores the lower-MSE candidate, and uses a 256 global E4M3 scale bound. + */ kNVTEQuantizationConfigNVFP44Over6 = 8, kNVTEQuantizationConfigNumAttributes }; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index aedd3458fe..48838621fb 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -525,6 +525,7 @@ class NVFP4BlockScaling(Recipe): enable_4over6 : bool, default = False If set to `True`, NVFP4 1D quantization evaluates per-block map-to-4 and map-to-6 candidates and chooses the one with lower MSE. + Ties choose map-to-6. The global E4M3 scale bound is 256 in this mode instead of 448. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 6f14fcdf89..2f411b3bd6 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -327,6 +327,7 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + // Whether emitted NVFP4 tensors use 4over6 candidate selection. bool use_4over6; // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. bool row_scaled_nvfp4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ceb654076d..6d05955148 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1201,6 +1201,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, config.set_nvfp4_4over6(quantizer.use_4over6); } + // Fast math affects the 4over6 MSE computation when 4over6 is enabled const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math) { for (auto &config : quant_config_list) { diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 1929dbe884..ced235a372 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -465,7 +465,12 @@ def _quantize_blockwise_4over6_reference( global_decode_scale: torch.Tensor, row_scaled_nvfp4: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize NVFP4 with 4over6 candidate selection.""" + """Quantize NVFP4 with 4over6 candidate selection. + + This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale, + MSE is computed in the original input domain with the 6 * 256 denominator, + and ties choose map-to-6. + """ m, num_blocks, tile_len_x = x.shape n = num_blocks * tile_len_x FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 2d0ed5ba89..625977ba48 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -130,7 +130,7 @@ class NVFP4Quantizer(Quantizer): """Whether emitted NVFP4 tensors store one FP32 amax per row.""" row_scaled_nvfp4: bool - """Whether to use NVFP4 4over6 block scale selection.""" + """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection.""" use_4over6: bool """RHT matrix random sign mask""" diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 97272ab954..713d04ccf2 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -322,7 +322,7 @@ def row_scaled_nvfp4(self, row_scaled_nvfp4: bool) -> None: @property def use_4over6(self) -> bool: - """Whether grouped NVFP4 tensors use 4over6 block scale selection.""" + """Whether grouped NVFP4 tensors carry 4over6 metadata.""" return self._use_4over6 @use_4over6.setter diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 94bf9cd3e6..ff4dcd8a79 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -99,7 +99,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage): _with_gemm_swizzled_scales: bool # Whether this NVFP4 tensor uses row-scaled amax metadata _row_scaled_nvfp4: bool - # Whether this NVFP4 tensor uses 4over6 block scale selection + # Whether this NVFP4 tensor uses 4over6 map-to-4/map-to-6 block selection _use_4over6: bool def __new__( From fe18a1e4d772202a477f20dd4cb7e0ac4733cae7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 10 May 2026 03:12:58 -0700 Subject: [PATCH 23/57] Drop unnecessary test_sanity workaround The following tests passed: `NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py ` `NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py ` Signed-off-by: Ziang Li --- tests/pytorch/test_sanity.py | 8 ++++---- .../include/transformer_engine/transformer_engine.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 277ed7ca4d..c4733bb7ba 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -628,12 +628,12 @@ def test_sanity_grouped_linear( single_grouped_bias=single_param, ).cuda() - # Verify grouped linear exposes grouped params when the experimental mode is enabled. + # Verify grouped linear exposes a single grouped weight parameter(and bias when applicable). if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): - if te_grouped_linear.single_grouped_weight: + if single_param: check_grouped_weight(te_grouped_linear, num_gemms, ffn_hidden_size, config.hidden_size) - if use_bias and te_grouped_linear.single_grouped_bias: - check_grouped_bias(te_grouped_linear, num_gemms, ffn_hidden_size) + if use_bias: + check_grouped_bias(te_grouped_linear, num_gemms, ffn_hidden_size) inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 8c8155f99b..1b78dcf826 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -83,7 +83,7 @@ enum NVTETensorParam { * its values are populated during quantization. */ kNVTERowScaledNVFP4 = 8, - kNVTENVFP44Over6 = 9, /*!< Whether an NVFP4 tensor uses 4over6 scaling */ + kNVTENVFP44Over6 = 9, /*!< Whether an NVFP4 tensor uses 4over6 scaling */ kNVTENumTensorParams }; From 522e93e6c8b078baea9de06ee432fac0381e6d6d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 14:51:32 -0700 Subject: [PATCH 24/57] Refactor `QuantizerRole` Signed-off-by: Ziang Li --- tests/cpp/operator/test_dequantize_nvfp4.cu | 3 +++ tests/cpp/test_common.cu | 6 ++++++ tests/cpp/test_common.h | 2 ++ tests/pytorch/test_recipe.py | 21 +++++++++++++++++++++ 4 files changed, 32 insertions(+) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index ea2ef14916..e1289b7d60 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -108,6 +108,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Configure quantized tensor amax size_t amax_size = 1; quantized.set_nvfp4_4over6(use_4over6); + ASSERT_EQ(quantized.nvfp4_4over6(), use_4over6); if (row_scaled_nvfp4) { quantized.set_row_scaled_nvfp4(true); amax_size = rows; @@ -170,6 +171,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); quantized_compact.set_nvfp4_4over6(use_4over6); + ASSERT_EQ(quantized_compact.nvfp4_4over6(), use_4over6); if (row_scaled_nvfp4) { quantized_compact.set_row_scaled_nvfp4(true); } else if (rows > 0 && cols > 0) { @@ -192,6 +194,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); quantized_swizzled.set_nvfp4_4over6(use_4over6); + ASSERT_EQ(quantized_swizzled.nvfp4_4over6(), use_4over6); if (row_scaled_nvfp4) { quantized_swizzled.set_row_scaled_nvfp4(true); } else { diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 6474540b39..4687160f1b 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -446,6 +446,12 @@ void Tensor::set_nvfp4_4over6(bool nvfp4_4over6) { tensor_.set_nvfp4_4over6(nvfp4_4over6); } +bool Tensor::nvfp4_4over6() const { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "NVFP4 4over6 is only supported for NVFP4 tensors."); + return tensor_.get_nvfp4_4over6(); +} + void Tensor::to_cpu() { if (data_rowwise_) { data_rowwise_->to_cpu(); } if (data_columnwise_) { data_columnwise_->to_cpu(); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 06afb86e7c..46fd320080 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -293,6 +293,8 @@ class Tensor { return columnwise_; } + bool nvfp4_4over6() const; + void set_tensor_amax_nullptr(); void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales); diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 5b18a53f70..1b906f291a 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, NVFP4BlockScalingRecipeState, + QuantizerRole, _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops @@ -530,15 +531,35 @@ def test_nvfp4_row_scaled_quantizer_roles(use_4over6): num_quantizers=3, ).make_quantizers() assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] + assert [q.use_4over6 for q in forward_quantizers] == [use_4over6] * 3 assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) + role_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="forward", + num_quantizers=4, + roles=[ + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="output"), + None, + ], + ).make_quantizers() + assert [q.row_scaled_nvfp4 for q in role_quantizers] == [False, True, True, True] + assert [q.use_4over6 for q in role_quantizers] == [use_4over6] * 4 + backward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="backward", num_quantizers=2, + roles=[ + QuantizerRole(module_type="linear", tensor_type="grad_output"), + QuantizerRole(module_type="linear", tensor_type="grad_input"), + ], ).make_quantizers() assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] + assert [q.use_4over6 for q in backward_quantizers] == [use_4over6] * 2 @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) From 782b7eefdb92791fa8be90e177037be2d3f293d3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 15:18:56 -0700 Subject: [PATCH 25/57] Allow separate recipe 4over6 config Signed-off-by: Ziang Li --- docs/envvars.rst | 8 ++--- tests/pytorch/test_cpu_offloading.py | 7 +++-- tests/pytorch/test_cuda_graphs.py | 2 +- tests/pytorch/test_numerics.py | 2 +- tests/pytorch/test_recipe.py | 31 ++++++++++++++++---- tests/pytorch/test_sanity.py | 2 +- tests/pytorch/test_torch_compile.py | 2 +- tests/pytorch/utils.py | 8 ++--- transformer_engine/common/recipe/__init__.py | 20 ++++++++----- transformer_engine/pytorch/quantization.py | 9 +++++- 10 files changed, 63 insertions(+), 28 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 3285c7ec5b..8cd6d4df36 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -287,11 +287,11 @@ Kernel Configuration :Default: ``0`` :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar. -.. envvar:: NVTE_NVFP4_ENABLE_4OVER6 +.. envvar:: NVTE_NVFP4_4OVER6 - :Type: ``int`` (0 or 1) - :Default: ``0`` - :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the ``NVFP4BlockScaling`` recipe. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of 448 bound, and currently requires RHT, stochastic rounding, and 2D quantization to be disabled. + :Type: ``str`` (``weights``, ``activations``, or ``all``) + :Default: unset + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 1D quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound, and currently requires RHT, stochastic rounding, and 2D quantization to be disabled. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 8020a2ba10..7500e95196 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -48,7 +48,7 @@ def nvfp4_4over6(): disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() @@ -207,6 +207,9 @@ def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) return quantizer(tensor) elif recipe.nvfp4(): qparams = recipe.fp4_quant_fwd_inp + use_4over6 = False + if recipe.nvfp4_4over6 in ("activations", "all"): + use_4over6 = True quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer( rowwise=True, columnwise=not recipe.row_scaled_activation, @@ -215,7 +218,7 @@ def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, row_scaled_nvfp4=recipe.row_scaled_activation, - use_4over6=recipe.enable_4over6, + use_4over6=use_4over6, ) return quantizer(tensor) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 0526b7b99a..ade160222b 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -83,7 +83,7 @@ def nvfp4_4over6(): disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index eeded6ff45..1b37a3803a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -157,7 +157,7 @@ def nvfp4_4over6(): disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 1b906f291a..4e4a4fdc0c 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -515,23 +515,38 @@ def test_quantizer_update(self, module_class): @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) -@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -def test_nvfp4_row_scaled_quantizer_roles(use_4over6): +@pytest.mark.parametrize( + "nvfp4_4over6", + [None, "weights", "activations", "all"], + ids=["default", "weights", "activations", "all"], +) +def test_nvfp4_row_scaled_quantizer_roles(nvfp4_4over6): recipe = NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=use_4over6, + nvfp4_4over6=nvfp4_4over6, row_scaled_activation=True, ) + def expected_use_4over6(tensor_type): + if nvfp4_4over6 == "all": + return True + if nvfp4_4over6 == "weights": + return tensor_type == "weight" + if nvfp4_4over6 == "activations": + return tensor_type != "weight" + return False + forward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="forward", num_quantizers=3, ).make_quantizers() assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] - assert [q.use_4over6 for q in forward_quantizers] == [use_4over6] * 3 + assert [q.use_4over6 for q in forward_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("input", "weight", "output") + ] assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) @@ -547,7 +562,9 @@ def test_nvfp4_row_scaled_quantizer_roles(use_4over6): ], ).make_quantizers() assert [q.row_scaled_nvfp4 for q in role_quantizers] == [False, True, True, True] - assert [q.use_4over6 for q in role_quantizers] == [use_4over6] * 4 + assert [q.use_4over6 for q in role_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("weight", "input", "output", "input") + ] backward_quantizers = NVFP4BlockScalingRecipeState( recipe, @@ -559,7 +576,9 @@ def test_nvfp4_row_scaled_quantizer_roles(use_4over6): ], ).make_quantizers() assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] - assert [q.use_4over6 for q in backward_quantizers] == [use_4over6] * 2 + assert [q.use_4over6 for q in backward_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("grad_output", "grad_input") + ] @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c4733bb7ba..cff5b4a88f 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -113,7 +113,7 @@ def nvfp4_4over6(): disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index c3eb024b94..b7869da7c7 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -59,7 +59,7 @@ def nvfp4_4over6(): disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, - enable_4over6=True, + nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 0439112d31..e425e785f5 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -151,7 +151,7 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: "disable_stochastic_rounding": True, "disable_2d_quantization": True, "row_scaled_activation": "row_scaled" in name, - "enable_4over6": "4over6" in name, + "nvfp4_4over6": "all" if "4over6" in name else None, } kwargs.update(recipe_kwargs) return transformer_engine.common.recipe.NVFP4BlockScaling(**kwargs) @@ -162,9 +162,9 @@ def recipe_id(recipe: Optional[Recipe]) -> str: """Readable pytest id for a quantization recipe.""" if not isinstance(recipe, Recipe): return "None" - if recipe.nvfp4() and recipe.row_scaled_activation and recipe.enable_4over6: + if recipe.nvfp4() and recipe.row_scaled_activation and recipe.nvfp4_4over6 is not None: return "NVFP4RowScaled4Over6BlockScaling" - if recipe.nvfp4() and recipe.enable_4over6: + if recipe.nvfp4() and recipe.nvfp4_4over6 is not None: return "NVFP44Over6BlockScaling" if recipe.nvfp4() and recipe.row_scaled_activation: return "NVFP4RowScaledBlockScaling" @@ -187,7 +187,7 @@ def skip_unsupported_backward_override( if ( quant_recipe is not None and quant_recipe.nvfp4() - and getattr(quant_recipe, "enable_4over6", False) + and quant_recipe.nvfp4_4over6 is not None and layer_type == "grouped_linear" ): pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 48838621fb..9feb696d06 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -13,6 +13,7 @@ _BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") +_NVFP4_4OVER6_SCOPES = (None, "weights", "activations", "all") class _FormatHelper(NamedTuple): @@ -522,10 +523,12 @@ class NVFP4BlockScaling(Recipe): If set to `True`, forward activation quantizers emit row-scaled NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored as a vector with one FP32 value per tensor row. - enable_4over6 : bool, default = False - If set to `True`, NVFP4 1D quantization evaluates per-block - map-to-4 and map-to-6 candidates and chooses the one with lower MSE. - Ties choose map-to-6. The global E4M3 scale bound is 256 in this mode instead of 448. + nvfp4_4over6 : {None, 'weights', 'activations', 'all'}, default = None + Select tensors that use NVFP4 4over6. In this mode NVFP4 1D + quantization evaluates per-block map-to-4 and map-to-6 candidates + and chooses the one with lower MSE. Ties choose map-to-6. The + global E4M3 scale bound is 256 in this mode instead of 448. The + ``activations`` scope applies to every non-weight tensor role. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -540,7 +543,7 @@ class NVFP4BlockScaling(Recipe): ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" - enable_4over6: bool = os.getenv("NVTE_NVFP4_ENABLE_4OVER6", "0") == "1" + nvfp4_4over6: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6", None) fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -556,7 +559,10 @@ def __post_init__(self) -> None: assert ( self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." - if self.enable_4over6: + assert ( + self.nvfp4_4over6 in _NVFP4_4OVER6_SCOPES + ), "NVTE_NVFP4_4OVER6 must be unset or one of: 'weights', 'activations', 'all'." + if self.nvfp4_4over6 is not None: assert self.disable_rht, "NVFP4 4over6 currently requires RHT to be disabled" assert ( self.disable_stochastic_rounding @@ -593,7 +599,7 @@ def _make_repr(self) -> str: f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " f"row_scaled_activation={self.row_scaled_activation}, " - f"enable_4over6={self.enable_4over6}, " + f"nvfp4_4over6={self.nvfp4_4over6}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 5642344143..da13991cf4 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1655,6 +1655,13 @@ def _qparams(tensor_type: str): def _make(tensor_type: str) -> NVFP4Quantizer: qparams = _qparams(tensor_type) + use_4over6 = False + if self.recipe.nvfp4_4over6 == "all": + use_4over6 = True + elif self.recipe.nvfp4_4over6 == "weights": + use_4over6 = tensor_type == "weight" + elif self.recipe.nvfp4_4over6 == "activations": + use_4over6 = tensor_type != "weight" return NVFP4Quantizer( fp4_dtype=self.dtype, rowwise=True, @@ -1668,7 +1675,7 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and tensor_type != "weight" and self.recipe.row_scaled_activation ), - use_4over6=self.recipe.enable_4over6, + use_4over6=use_4over6, ) if self.mode not in ("forward", "backward"): From d9cd12c5f431f78fc09af6de01f5c8310a6d7e00 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 19:45:30 -0700 Subject: [PATCH 26/57] Support 2d Signed-off-by: Ziang Li --- docs/envvars.rst | 2 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 189 ++++++++-- .../nvfp4/test_nvfp4_quantize_exact.py | 8 +- tests/pytorch/test_cpu_offloading.py | 3 +- tests/pytorch/test_cuda_graphs.py | 3 +- tests/pytorch/test_numerics.py | 3 +- tests/pytorch/test_sanity.py | 3 +- tests/pytorch/test_torch_compile.py | 3 +- tests/pytorch/utils.py | 5 +- .../common/cast/dispatch/quantize.cuh | 4 - .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 237 ++++++++++++- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 328 ++++++++++++------ transformer_engine/common/recipe/__init__.py | 9 +- ...quantize_transpose_vector_blockwise_fp4.cu | 82 ++++- transformer_engine/pytorch/csrc/quantizer.cpp | 2 - .../custom_recipes/quantization_ref_nvfp4.py | 15 +- 16 files changed, 708 insertions(+), 188 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 8cd6d4df36..d4302ea91d 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -291,7 +291,7 @@ Kernel Configuration :Type: ``str`` (``weights``, ``activations``, or ``all``) :Default: unset - :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 1D quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound, and currently requires RHT, stochastic rounding, and 2D quantization to be disabled. + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index a8f58f8598..63b1e00462 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -62,8 +62,12 @@ std::vector create_transpose(const InputType* const input, const size } // Compute the global encode scale factor for a given global amax -float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) { - constexpr float fp8_max = 448.0f; // 448.0f; +float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math, + const bool use_4over6 = false) { + float fp8_max = 448.0f; + if (use_4over6) { + fp8_max = 256.0f; + } constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return the max normalized value @@ -89,10 +93,11 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t cols, const size_t scales_stride, const float global_amax, - const bool use_fast_math) { + const bool use_fast_math, + const bool use_4over6 = false) { // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); constexpr size_t block_size_X = 16; const size_t blocks_X = divide_round_up(cols, block_size_X); @@ -122,6 +127,90 @@ void quantize_nvfp4_1d(float (*OP)(const float), block_amax = std::max(block_amax, std::abs(elt)); } + const size_t scale_idx = i * scales_stride + block_X; + + if (use_4over6) { + const float S_dec_b_map6 = block_amax * (S_enc * (1.0f / 6.0f)); + const fp8e4m3 S_dec_b_fp8_map6 = static_cast(S_dec_b_map6); + const fp8e4m3 S_dec_b_fp8_map4 = static_cast(S_dec_b_map6 * 1.5f); + const float S_dec_b_fp32_map6 = static_cast(S_dec_b_fp8_map6); + const float S_dec_b_fp32_map4 = static_cast(S_dec_b_fp8_map4); + + float scale_reciprocal_map6 = 0.0f; + if (S_dec_b_fp32_map6 != 0.0f) { + scale_reciprocal_map6 = + fminf(S_enc / S_dec_b_fp32_map6, Numeric_Traits::maxNorm); + } + + float scale_reciprocal_map4 = 0.0f; + if (S_dec_b_fp32_map4 != 0.0f) { + scale_reciprocal_map4 = + fminf(S_enc / S_dec_b_fp32_map4, Numeric_Traits::maxNorm); + } + + if (use_fast_math) { + scale_reciprocal_map6 = static_cast(static_cast(scale_reciprocal_map6)); + scale_reciprocal_map4 = static_cast(static_cast(scale_reciprocal_map4)); + } + + std::array output_map6; + std::array output_map4; + float err_map6 = 0.0f; + float err_map4 = 0.0f; + constexpr float mse_scale = 1.0f / (6.0f * 256.0f); + + for (size_t j = j_min; j < j_max; j += 2) { + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + + const float2 scaled_elt_pair_map6 = { + cached_x * scale_reciprocal_map6, + cached_y * scale_reciprocal_map6, + }; + const fp4e2m1x2 casted_to_e2m1_pair_map6(scaled_elt_pair_map6); + output_map6[cache_idx_x / 2] = casted_to_e2m1_pair_map6; + const double2 truncated_pair_map6 = cvt_fp4x2_to_double2(casted_to_e2m1_pair_map6); + const double dequant_x_map6 = + truncated_pair_map6.x * S_dec_b_fp32_map6 * global_amax * mse_scale; + const double dequant_y_map6 = + truncated_pair_map6.y * S_dec_b_fp32_map6 * global_amax * mse_scale; + err_map6 += (dequant_x_map6 - cached_x) * (dequant_x_map6 - cached_x); + err_map6 += (dequant_y_map6 - cached_y) * (dequant_y_map6 - cached_y); + + const float2 scaled_elt_pair_map4 = { + cached_x * scale_reciprocal_map4, + cached_y * scale_reciprocal_map4, + }; + const fp4e2m1x2 casted_to_e2m1_pair_map4(scaled_elt_pair_map4); + output_map4[cache_idx_x / 2] = casted_to_e2m1_pair_map4; + const double2 truncated_pair_map4 = cvt_fp4x2_to_double2(casted_to_e2m1_pair_map4); + const double dequant_x_map4 = + truncated_pair_map4.x * S_dec_b_fp32_map4 * global_amax * mse_scale; + const double dequant_y_map4 = + truncated_pair_map4.y * S_dec_b_fp32_map4 * global_amax * mse_scale; + err_map4 += (dequant_x_map4 - cached_x) * (dequant_x_map4 - cached_x); + err_map4 += (dequant_y_map4 - cached_y) * (dequant_y_map4 - cached_y); + } + + const bool pick_map4 = err_map4 < err_map6; + if (pick_map4) { + scales[scale_idx] = S_dec_b_fp8_map4; + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + output[idx_pair] = output_map4[(j - j_min) / 2]; + } + } else { + scales[scale_idx] = S_dec_b_fp8_map6; + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + output[idx_pair] = output_map6[(j - j_min) / 2]; + } + } + continue; + } + // Compute and store the per-block FP8 decode scale const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f)); const fp8e4m3 S_dec_b_fp8 = static_cast(fminf(S_dec_b, Numeric_Traits::maxNorm)); @@ -131,7 +220,6 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits::maxNorm); - const size_t scale_idx = i * scales_stride + block_X; scales[scale_idx] = S_dec_b_fp8; float scale_reciprocal = S_enc_b_fp8; @@ -167,9 +255,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float), const size_t cols, const float global_amax, std::vector>& math_scales, - const bool use_fast_math) { + const bool use_fast_math, + const bool use_4over6 = false) { - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -214,13 +303,15 @@ void quantize_nvfp4_2d(float (*OP)(const float), const size_t cols, const size_t scales_stride, const float global_amax, - const bool use_fast_math) { + const bool use_fast_math, + const bool use_4over6 = false) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math, + use_4over6); - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -302,11 +393,14 @@ void quantize_nvfp4(float (*OP)(const float), const size_t scales_stride, const float global_amax, const bool use_fast_math, - const bool use_2d_quantization = false) { + const bool use_2d_quantization = false, + const bool use_4over6 = false) { if (use_2d_quantization) { - quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, + use_fast_math, use_4over6); } else { - quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, + use_fast_math, use_4over6); } } @@ -324,7 +418,8 @@ void compute_ref(float (*OP)(const float), const size_t scales_stride_t, const bool use_fast_math, const bool use_2d_quantization = false, - const bool row_scaled_nvfp4 = false) + const bool row_scaled_nvfp4 = false, + const bool use_4over6 = false) { std::vector input_t = create_transpose(input, rows, cols); NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4), @@ -334,7 +429,8 @@ void compute_ref(float (*OP)(const float), if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math); + compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math, + use_4over6); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -362,9 +458,9 @@ void compute_ref(float (*OP)(const float), // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, *amax, - use_fast_math); // scales already filled + use_fast_math, use_4over6); // scales already filled quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, *amax, - use_fast_math); // scales_t already filled + use_fast_math, use_4over6); // scales_t already filled return; } @@ -381,16 +477,17 @@ void compute_ref(float (*OP)(const float), scales_stride, amax[row], use_fast_math, - use_2d_quantization); + use_2d_quantization, + use_4over6); } return; } // Ref impl for basic NVFP4 quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, *amax, - use_fast_math, use_2d_quantization); + use_fast_math, use_2d_quantization, use_4over6); quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, *amax, - use_fast_math, use_2d_quantization); + use_fast_math, use_2d_quantization, use_4over6); } void compare_nvfp4_tensors(const std::string& name, @@ -529,7 +626,8 @@ template void performTest(float (*OP)(const float), const std::vector& shape, const bool use_fast_math, - const bool row_scaled_nvfp4 = false) { + const bool row_scaled_nvfp4 = false, + const bool use_4over6 = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -560,6 +658,7 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, rowwise, columnwise, NVTE_NVFP4_1D_SCALING); + output.set_nvfp4_4over6(use_4over6); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -587,7 +686,11 @@ void performTest(float (*OP)(const float), output.set_row_scaled_nvfp4(row_scaled_nvfp4); } else { // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues - ref_amax.assign(1, 448.0f * 6.0f * 8.0f); + if (use_4over6) { + ref_amax.assign(1, 256.0f * 6.0f * 8.0f); + } else { + ref_amax.assign(1, 448.0f * 6.0f * 8.0f); + } // Update tensor if (rowwise) { @@ -614,7 +717,8 @@ void performTest(float (*OP)(const float), scales_stride_t, use_fast_math, use_2d_quantization, - row_scaled_nvfp4); + row_scaled_nvfp4, + use_4over6); // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); @@ -628,6 +732,7 @@ void performTest(float (*OP)(const float), quant_config.set_stochastic_rounding(false); quant_config.set_rng_state(rng_state.data()); quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + quant_config.set_nvfp4_4over6(use_4over6); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -707,6 +812,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam std::vector, transformer_engine::DType, bool, + bool, bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { @@ -723,6 +829,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); const bool row_scaled_nvfp4 = std::get<4>(GetParam()); + const bool use_4over6 = std::get<5>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -740,7 +847,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, row_scaled_nvfp4); + performTest(OP, tensor_dims, use_fast_math, row_scaled_nvfp4, use_4over6); ); } @@ -764,6 +871,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(tensor_dims), ::testing::Values(DType::kBFloat16), ::testing::Values(false), + ::testing::Values(false), ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); @@ -774,6 +882,9 @@ INSTANTIATE_TEST_SUITE_P( name += "X" + test::typeName(std::get<2>(info.param)); if (std::get<3>(info.param)) { name += "X_FAST_SCALING"; + } + if (std::get<5>(info.param)) { + name += "X4OVER6"; } return name; }); @@ -786,7 +897,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), ::testing::Values(DType::kBFloat16, DType::kFloat32), ::testing::Values(false), - ::testing::Values(true)), + ::testing::Values(true), + ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); const auto& shape = std::get<1>(info.param); @@ -800,5 +912,32 @@ INSTANTIATE_TEST_SUITE_P( if (std::get<4>(info.param)) { name += "XROW_SCALED"; } + if (std::get<5>(info.param)) { + name += "X4OVER6"; + } + return name; + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest4Over6, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::Values(ActivationType::Identity), + ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), + ::testing::Values(DType::kFloat32), + ::testing::Values(false), + ::testing::Values(false, true), + ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<4>(info.param)) { + name += "XROW_SCALED"; + } + name += "X4OVER6"; return name; }); diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index d7f9c8994e..60c1543407 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -21,9 +21,13 @@ def maybe_skip_row_scaled_unsupported_quantization( return_transpose: bool, with_2d_quantization: bool = False, use_4over6: bool = False, + x_dtype: torch.dtype | None = None, + M: int | None = None, + N: int | None = None, ) -> None: if use_4over6 and with_2d_quantization: - pytest.skip("NVFP4 4over6 does not support 2D quantization") + if x_dtype != torch.bfloat16 or M is None or N is None or M % 32 != 0 or N % 32 != 0: + pytest.skip("NVFP4 2D 4over6 exact tests require the optimized BF16 kernel path") if not row_scaled_nvfp4: return if return_transpose: @@ -51,7 +55,7 @@ def check_quantization_nvfp4_versus_reference( use_4over6: bool = False, ) -> None: maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6 + row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6, x_dtype, M, N ) te_dtype = tex.DType.kFloat4E2M1 diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 7500e95196..47505c5be0 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -47,11 +47,10 @@ def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, - disable_2d_quantization=True, nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() - nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index ade160222b..bb4a4e3857 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -82,11 +82,10 @@ def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, - disable_2d_quantization=True, nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() - nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 1b37a3803a..5f82bfcba2 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -156,11 +156,10 @@ def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, - disable_2d_quantization=True, nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() - nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index cff5b4a88f..27eafbecdc 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -112,11 +112,10 @@ def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, - disable_2d_quantization=True, nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() - nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index b7869da7c7..137e5f5a77 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -58,11 +58,10 @@ def nvfp4_4over6(): nvfp4_recipe = recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, - disable_2d_quantization=True, nvfp4_4over6="all", ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() - nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index e425e785f5..c58d3e751e 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -146,12 +146,13 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) if "nvfp4" in name: + use_4over6 = "4over6" in name kwargs = { "disable_rht": True, "disable_stochastic_rounding": True, - "disable_2d_quantization": True, + "disable_2d_quantization": not use_4over6, "row_scaled_activation": "row_scaled" in name, - "nvfp4_4over6": "all" if "4over6" in name else None, + "nvfp4_4over6": "all" if use_4over6 else None, } kwargs.update(recipe_kwargs) return transformer_engine.common.recipe.NVFP4BlockScaling(**kwargs) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index ca0fed9f16..597d034844 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -102,8 +102,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, auto dtype = input_tensor->dtype(); const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; const bool use_4over6 = quant_config_cpp.nvfp4_4over6 || output_tensor->nvfp4_4over6; - NVTE_CHECK(!use_4over6 || !quant_config_cpp.nvfp4_2d_quantization, - "NVFP4 4over6 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); quant_config_cpp.nvfp4_4over6 = use_4over6; @@ -259,8 +257,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); const bool use_4over6 = quant_config_cpp.nvfp4_4over6 || output_tensor->nvfp4_4over6; - NVTE_CHECK(!use_4over6 || !quant_config_cpp.nvfp4_2d_quantization, - "NVFP4 4over6 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); quant_config_cpp.nvfp4_4over6 = use_4over6; diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 6548f3a5bc..42fe61d9b9 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -234,12 +234,7 @@ template &scaling_factors, const float global_amax, - nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { - float err_map4 = 0.0f; - float err_map6 = 0.0f; - __align__(8) uint32_t rOut_map4[2]; - __align__(8) uint32_t rOut_map6[2]; - + float &err_map4, float &err_map6, uint32_t (&rOut_map4)[2], uint32_t (&rOut_map6)[2]) { if constexpr (REVERSE_PACK_ORDER) { rOut_map4[1] = cvt_fp32_to_fp4_8x_with_mse_rn( second_half, static_cast(scaling_factors.SFcoefficient_map4), @@ -267,6 +262,21 @@ __device__ __forceinline__ void quantize_4over6_16x( second_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } +} + +template +__device__ __forceinline__ void quantize_4over6_16x( + const float (&first_half)[8], const float (&second_half)[8], + const QuantizationScales4Over6 &scaling_factors, const float global_amax, + nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + float err_map4 = 0.0f; + float err_map6 = 0.0f; + __align__(8) uint32_t rOut_map4[2]; + __align__(8) uint32_t rOut_map6[2]; + + quantize_4over6_16x(first_half, second_half, scaling_factors, + global_amax, err_map4, err_map6, rOut_map4, + rOut_map6); if (err_map4 < err_map6) { S_dec_b_fp8 = scaling_factors.S_dec_b_fp8_map4; @@ -279,6 +289,172 @@ __device__ __forceinline__ void quantize_4over6_16x( } } +struct QuantizationCandidates4Over6 { + float err_map4; + float err_map6; + uint32_t rOut_map4[2]; + uint32_t rOut_map6[2]; +}; + +template +__device__ __forceinline__ void quantize_4over6_candidates_16x( + const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, + const float global_amax, QuantizationCandidates4Over6 &candidates) { + float first_half[8]; + float second_half[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + first_half[i] = x[i]; + second_half[i] = x[i + 8]; + } + + candidates.err_map4 = 0.0f; + candidates.err_map6 = 0.0f; + quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, + candidates.err_map4, candidates.err_map6, candidates.rOut_map4, + candidates.rOut_map6); +} + +template +__device__ __forceinline__ void reduce_4over6_2d_block_selection( + const size_t block_in_tile_y, const size_t reduce_thread_idx, const float global_encode_scale, + const float global_decode_scale, + float (&block_amax_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1], + float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], + float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], + uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], + nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]) { + if (reduce_thread_idx < BLOCKS_PER_TILE_X) { + const size_t reduce_block_x = reduce_thread_idx; + float block_err_map4 = 0.0f; + float block_err_map6 = 0.0f; +#pragma unroll + for (int i = 0; i < BLOCK_DIM; ++i) { + block_err_map4 += err_map4_matrix[block_in_tile_y][reduce_block_x][i]; + block_err_map6 += err_map6_matrix[block_in_tile_y][reduce_block_x][i]; + } + + const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( + block_amax_matrix[block_in_tile_y][reduce_block_x], global_encode_scale, + global_decode_scale); + if (block_err_map4 < block_err_map6) { + pick_map4_matrix[block_in_tile_y][reduce_block_x] = 1; + selected_scale_matrix[block_in_tile_y][reduce_block_x] = scaling_factors.S_dec_b_fp8_map4; + } else { + pick_map4_matrix[block_in_tile_y][reduce_block_x] = 0; + selected_scale_matrix[block_in_tile_y][reduce_block_x] = scaling_factors.S_dec_b_fp8_map6; + } + } +} + +template +__device__ __forceinline__ void record_and_reduce_4over6_2d_block_selection( + const float block_amax, const float global_encode_scale, const float global_decode_scale, + const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, + float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], + float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], + uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], + nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], + const QuantizationCandidates4Over6 &candidates) { + err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; + err_map6_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map6; + __syncthreads(); + + if (participant_idx == 0) { + float block_err_map4 = 0.0f; + float block_err_map6 = 0.0f; +#pragma unroll + for (int i = 0; i < BLOCK_DIM; ++i) { + block_err_map4 += err_map4_matrix[block_in_tile_y][block_in_tile_x][i]; + block_err_map6 += err_map6_matrix[block_in_tile_y][block_in_tile_x][i]; + } + + const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( + block_amax, global_encode_scale, global_decode_scale); + if (block_err_map4 < block_err_map6) { + pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 1; + selected_scale_matrix[block_in_tile_y][block_in_tile_x] = scaling_factors.S_dec_b_fp8_map4; + } else { + pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 0; + selected_scale_matrix[block_in_tile_y][block_in_tile_x] = scaling_factors.S_dec_b_fp8_map6; + } + } + __syncthreads(); +} + +template +__device__ __forceinline__ void quantize_4over6_2d_block_candidate( + const float (&x)[16], const float block_amax, const float global_encode_scale, + const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, + const size_t block_in_tile_x, const size_t participant_idx, const size_t reduce_thread_idx, + float (&block_amax_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1], + float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], + float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], + uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], + nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], + QuantizationCandidates4Over6 &candidates) { + const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( + block_amax, global_encode_scale, global_decode_scale); + quantize_4over6_candidates_16x(x, scaling_factors, global_amax, candidates); + + err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; + err_map6_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map6; + __syncthreads(); + + reduce_4over6_2d_block_selection( + block_in_tile_y, reduce_thread_idx, global_encode_scale, global_decode_scale, + block_amax_matrix, err_map4_matrix, err_map6_matrix, pick_map4_matrix, selected_scale_matrix); + __syncthreads(); +} + +__device__ __forceinline__ uint32_t *selected_4over6_packed( + const bool pick_map4, QuantizationCandidates4Over6 &candidates) { + if (pick_map4) { + return candidates.rOut_map4; + } + return candidates.rOut_map6; +} + +template +__device__ __forceinline__ void store_4over6_colwise_packed_16x( + const bool pick_map4, QuantizationCandidates4Over6 &candidates, const int thread_lane, + output_type *out_t_data_sh, const size_t shmem_offset_base_colwise_out_t) { + uint32_t *regs_4x = selected_4over6_packed(pick_map4, candidates); + const int group = thread_lane / 16; + uint32_t val[2]; + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; +} + +template +__device__ __forceinline__ void store_4over6_rowwise_packed_16x( + const bool pick_map4, QuantizationCandidates4Over6 &candidates, const int bank_group, + const size_t thread_offset_X_rowwise, const size_t shmem_offset_base_rowwise_out, + output_type *out_data_sh) { + uint32_t *packed = selected_4over6_packed(pick_map4, candidates); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + uint32_t *out_data_sh_as_uint32_t = + reinterpret_cast(&out_data_sh[shmem_offset_rowwise]); + out_data_sh_as_uint32_t[0] = packed[swizzled_group_idx / PACK_SIZE]; + } +} + template __device__ __forceinline__ void store_4over6_packed_16x(const uint32_t (&packed)[2], output_vec_type &output_vec) { @@ -286,6 +462,14 @@ __device__ __forceinline__ void store_4over6_packed_16x(const uint32_t (&packed) *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; } +template +__device__ __forceinline__ void store_selected_4over6_packed_16x( + const bool pick_map4, QuantizationCandidates4Over6 &candidates, output_vec_type &output_vec) { + uint32_t *packed = selected_4over6_packed(pick_map4, candidates); + *reinterpret_cast(&output_vec.data.elt[0]) = packed[0]; + *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; +} + template __device__ __forceinline__ void quantize_4over6_contiguous_16x( @@ -322,6 +506,27 @@ __device__ __forceinline__ void quantize_4over6_pair_array_16x( global_amax, S_dec_b_fp8, rOut); } +template +__device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( + const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, + const float global_amax, QuantizationCandidates4Over6 &candidates) { + float first_half[8]; + float second_half[8]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + first_half[2 * i] = static_cast(x[i].data.elt[0]); + first_half[2 * i + 1] = static_cast(x[i].data.elt[1]); + second_half[2 * i] = static_cast(x[i + 4].data.elt[0]); + second_half[2 * i + 1] = static_cast(x[i + 4].data.elt[1]); + } + + candidates.err_map4 = 0.0f; + candidates.err_map6 = 0.0f; + quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, + candidates.err_map4, candidates.err_map6, candidates.rOut_map4, + candidates.rOut_map6); +} + template __device__ __forceinline__ void quantize_4over6_vec2_array_16x( @@ -341,6 +546,26 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_16x( global_amax, S_dec_b_fp8, rOut); } +template +__device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( + const vec_type (&x)[16], const int idx, + const QuantizationScales4Over6 &scaling_factors, const float global_amax, + QuantizationCandidates4Over6 &candidates) { + float first_half[8]; + float second_half[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + first_half[i] = static_cast(x[i].data.elt[idx]); + second_half[i] = static_cast(x[i + 8].data.elt[idx]); + } + + candidates.err_map4 = 0.0f; + candidates.err_map6 = 0.0f; + quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, + candidates.err_map4, candidates.err_map6, candidates.rOut_map4, + candidates.rOut_map6); +} + template __device__ __forceinline__ void quantize_4over6_vec_index_16x( diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 9e4aef5a1c..d25155f4de 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -23,6 +23,7 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" +#include "quantize_4over6_nvfp4.cuh" #include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh" namespace transformer_engine { @@ -779,7 +780,8 @@ __global__ void __launch_bounds__(THREADS_NUM) } template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH, bool RETURN_TRANSPOSE, + bool USE_4OVER6> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -893,16 +895,21 @@ __global__ void __launch_bounds__(THREADS_NUM) const bool is_master_thread = (threadIdx.x == 0); // Compute a global encoding/decoding scaling factors for all S_dec_b - const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) - ? 1.0f - : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + const float S_enc_rowwise = + (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); // NOTE: This is to match with how emulation code was written. const float S_dec_rowwise = 1.0 / S_enc_rowwise; - const float S_enc_colwise = (amax_colwise_ptr == nullptr) - ? S_enc_rowwise - : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_enc_colwise = + (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); const float S_dec_colwise = 1.0 / S_enc_colwise; + const float global_amax_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f : *amax_rowwise_ptr; + const float global_amax_colwise = + (amax_colwise_ptr == nullptr) ? global_amax_rowwise : *amax_colwise_ptr; const size_t warp_id = threadIdx.x / 32; const size_t lane_id = threadIdx.x % 32; @@ -914,6 +921,11 @@ __global__ void __launch_bounds__(THREADS_NUM) __shared__ alignas(8) uint64_t mbar[STAGES]; __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + __shared__ __align__(16) float err_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; + __shared__ __align__(16) float err_map6_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; + __shared__ __align__(16) uint8_t pick_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; + __shared__ __align__(16) + nvfp4_scale_t selected_scale_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; // Helper function for warp reduction auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { @@ -1074,56 +1086,86 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise); + if constexpr (USE_4OVER6) { + float x_4over6[SCALE_DIM]; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + x_4over6[i] = static_cast(in_colwise_IType[i]); + } else { + x_4over6[i] = in_compute_colwise[i]; + } + } - // // Store scaling factors through SHMEM - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + QuantizationCandidates4Over6 candidates; + const size_t block_col = threadIdx.x % BLOCK_DIM; + quantize_4over6_2d_block_candidate( + x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, + block_in_tile_y, block_in_tile_x, block_col, threadIdx.x, block_amax_matrix, + err_map4_matrix, err_map6_matrix, pick_map4_matrix, selected_scale_matrix, + candidates); + + const nvfp4_scale_t S_dec_b_fp8 = selected_scale_matrix[block_in_tile_y][block_in_tile_x]; + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + const bool pick_map4 = pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; + store_4over6_colwise_packed_16x(pick_map4, candidates, thread_lane, out_t_data_sh, + shmem_offset_base_colwise_out_t); + } else { + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - fp4e2m1x4 regs[SCALE_DIM / 4]; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + fp4e2m1x4 regs[SCALE_DIM / 4]; #pragma unroll - for (int e = 0; e < SCALE_DIM / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); - regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else { - const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); - const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); - regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } } - } - const int group = thread_lane / 16; - uint32_t val[2]; - uint32_t *regs_4x = reinterpret_cast(regs); - - // Helps reducing bank conflicts - switch (group) { - case 0: - val[0] = regs_4x[0]; - val[1] = regs_4x[1]; - break; - case 1: - val[0] = regs_4x[1]; - val[1] = regs_4x[0]; - break; + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } - uint32_t *out_t_data_sh_as_uint32_t = - reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); - out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; - out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } } @@ -1143,6 +1185,7 @@ __global__ void __launch_bounds__(THREADS_NUM) block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; float in_compute_rowwise[SCALE_DIM]; + float in_4over6_rowwise[SCALE_DIM]; Vec in_cached[WAVES]; // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY @@ -1158,6 +1201,15 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (USE_4OVER6) { +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + in_4over6_rowwise[swizzled_group_idx + 2 * e] = + static_cast(in_IType[w].data.elt[e].x); + in_4over6_rowwise[swizzled_group_idx + 2 * e + 1] = + static_cast(in_IType[w].data.elt[e].y); + } + } } } else if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads @@ -1170,6 +1222,13 @@ __global__ void __launch_bounds__(THREADS_NUM) // Load cached elements in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (USE_4OVER6) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + in_4over6_rowwise[swizzled_group_idx + e] = + static_cast(in_cached[w].data.elt[e]); + } + } } } else { #pragma unroll @@ -1195,62 +1254,93 @@ __global__ void __launch_bounds__(THREADS_NUM) elt = static_cast(static_cast(elt)); } in_compute_rowwise[j] = elt; + if constexpr (USE_4OVER6) { + in_4over6_rowwise[swizzled_group_idx + e] = elt; + } } } } - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - - // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } + if constexpr (USE_4OVER6) { + QuantizationCandidates4Over6 candidates; + quantize_4over6_2d_block_candidate( + in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, + block_in_tile_y, block_in_tile_x, tid_Y_rowwise, threadIdx.x, block_amax_matrix, + err_map4_matrix, err_map6_matrix, pick_map4_matrix, selected_scale_matrix, + candidates); - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + const nvfp4_scale_t S_dec_b_fp8 = selected_scale_matrix[block_in_tile_y][block_in_tile_x]; + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - // 3. Scale elements + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + const bool pick_map4 = pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; + store_4over6_rowwise_packed_16x( + pick_map4, candidates, bank_group, thread_offset_X_rowwise, + shmem_offset_base_rowwise_out, out_data_sh); + } else { + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements #pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; + for (int w = 0; w < WAVES; ++w) { + Vec out; #pragma unroll - for (int e = 0; e < PACK_SIZE / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - IType2 in01; - IType2 in23; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); - out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else if constexpr (IS_CACHED_ACT_OP) { - const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); - out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else { - const int j = w * PACK_SIZE + 4 * e; - const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); - const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); - out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = + make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } } - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - out.store_to(&out_data_sh[shmem_offset_rowwise]); + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } } } } @@ -1319,9 +1409,13 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to // return the transposed data. @@ -1431,25 +1525,31 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(use_4over6, USE_4OVER6, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = + quantize_transpose_nvfp4_2D_kernel; + } - TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; - - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel; - } - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - }); - });); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }); + }); + }););); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9feb696d06..3b1fdb4e3f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -524,11 +524,13 @@ class NVFP4BlockScaling(Recipe): NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored as a vector with one FP32 value per tensor row. nvfp4_4over6 : {None, 'weights', 'activations', 'all'}, default = None - Select tensors that use NVFP4 4over6. In this mode NVFP4 1D + Select tensors that use NVFP4 4over6. In this mode NVFP4 quantization evaluates per-block map-to-4 and map-to-6 candidates and chooses the one with lower MSE. Ties choose map-to-6. The global E4M3 scale bound is 256 in this mode instead of 448. The ``activations`` scope applies to every non-weight tensor role. + Random Hadamard transforms and stochastic rounding are not yet + supported on tensors that use 4over6. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -562,14 +564,11 @@ def __post_init__(self) -> None: assert ( self.nvfp4_4over6 in _NVFP4_4OVER6_SCOPES ), "NVTE_NVFP4_4OVER6 must be unset or one of: 'weights', 'activations', 'all'." - if self.nvfp4_4over6 is not None: + if self.nvfp4_4over6 in ("activations", "all"): assert self.disable_rht, "NVFP4 4over6 currently requires RHT to be disabled" assert ( self.disable_stochastic_rounding ), "NVFP4 4over6 currently requires stochastic rounding to be disabled" - assert ( - self.disable_2d_quantization - ), "NVFP4 4over6 currently requires 2D quantization to be disabled" # Quantization params # Note: RHT is currently only applied to column-wise usage so that diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index b69f790531..4db1c1d63f 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -361,6 +361,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1; __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim]; __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; + constexpr int k4Over62DSelectionDim = + (kUse4Over6 && kIs2DBlockScaling) ? kFP4BlockScalingSize : 1; + __shared__ float err_map4_smem[k2DBlockAmaxDim][k2DBlockAmaxDim][k4Over62DSelectionDim]; + __shared__ float err_map6_smem[k2DBlockAmaxDim][k2DBlockAmaxDim][k4Over62DSelectionDim]; + __shared__ uint8_t pick_map4_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; + __shared__ ScaleType selected_scale_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; // Step 1: Load input to shared memory { @@ -530,11 +536,37 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo row_global_amax = global_amax[0]; } - uint32_t output_vec_4over6[2]; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_16x( - smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); - transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, - output_vec); + if constexpr (kIs2DBlockScaling) { + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; + const int warp_idx = threadIdx.x / kThreadsPerWarp; + const int tid_in_warp_x = threadIdx.x % kNumThreadsStore; + const int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + const int data_row_idx = + iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + const size_t block_in_tile_y = data_row_idx / kFP4BlockScalingSize; + const size_t block_in_tile_x = tid_in_warp_x; + const size_t participant_idx = data_row_idx % kFP4BlockScalingSize; + + transformer_engine::dispatch::nvfp4::core::QuantizationCandidates4Over6 candidates; + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_candidates_16x< + kUseFastMath>(smem_vec, scaling_factors, row_global_amax, candidates); + transformer_engine::dispatch::nvfp4::core::record_and_reduce_4over6_2d_block_selection< + kFP4BlockScalingSize, k2DBlockAmaxDim, k2DBlockAmaxDim>( + amax, row_global_encode_scale, row_global_decode_scale, block_in_tile_y, + block_in_tile_x, participant_idx, err_map4_smem, err_map6_smem, pick_map4_smem, + selected_scale_smem, candidates); + + const bool pick_map4 = pick_map4_smem[block_in_tile_y][block_in_tile_x] == 1; + scale_inv = selected_scale_smem[block_in_tile_y][block_in_tile_x]; + transformer_engine::dispatch::nvfp4::core::store_selected_4over6_packed_16x( + pick_map4, candidates, output_vec); + } else { + uint32_t output_vec_4over6[2]; + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_16x( + smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); + transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, + output_vec); + } } else { scale_inv = ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); @@ -670,11 +702,39 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo compute_4over6_fp4_encode_quantization_scaling_factors(amax, global_encode_scale, global_decode_scale); - uint32_t output_vec_4over6[2]; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_16x( - smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); - transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, - output_vec); + if constexpr (kIs2DBlockScaling) { + const int warp_idx = threadIdx.x / kThreadsPerWarp; + constexpr int kNumColsPerWarp = kThreadsPerWarp / kNumThreadsStore * kNVecSMem; + constexpr int kNumWarpsPerBlock = kThreadsPerBlock / kThreadsPerWarp; + constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock; + const int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp; + const int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore; + const int data_col_idx = + iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x; + const size_t block_in_tile_y = tid_in_warp_y; + const size_t block_in_tile_x = data_col_idx / kFP4BlockScalingSize; + const size_t participant_idx = data_col_idx % kFP4BlockScalingSize; + + transformer_engine::dispatch::nvfp4::core::QuantizationCandidates4Over6 candidates; + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_candidates_16x< + kUseFastMath>(smem_vec, smem_idx, scaling_factors, global_amax[0], candidates); + transformer_engine::dispatch::nvfp4::core::record_and_reduce_4over6_2d_block_selection< + kFP4BlockScalingSize, k2DBlockAmaxDim, k2DBlockAmaxDim>( + amax, global_encode_scale, global_decode_scale, block_in_tile_y, block_in_tile_x, + participant_idx, err_map4_smem, err_map6_smem, pick_map4_smem, selected_scale_smem, + candidates); + + const bool pick_map4 = pick_map4_smem[block_in_tile_y][block_in_tile_x] == 1; + scale_inv = selected_scale_smem[block_in_tile_y][block_in_tile_x]; + transformer_engine::dispatch::nvfp4::core::store_selected_4over6_packed_16x( + pick_map4, candidates, output_vec); + } else { + uint32_t output_vec_4over6[2]; + transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_16x( + smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); + transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, + output_vec); + } } else { scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); @@ -772,8 +832,6 @@ void quantize_transpose_vector_blockwise_fp4( "Row-scaled NVFP4 quantization only supports rowwise quantization."); NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); - NVTE_CHECK(!use_4over6 || !use_2d_quantization, - "NVFP4 4over6 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b7abf0da7c..b18a423864 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2298,8 +2298,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou if (this->use_4over6) { NVTE_CHECK(!this->with_rht, "NVFP4 4over6 quantization does not support RHT."); - NVTE_CHECK(!this->with_2d_quantization, - "NVFP4 4over6 quantization does not support 2D quantization."); NVTE_CHECK(!this->stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); } diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index ced235a372..3bc84ca785 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -365,8 +365,8 @@ def __init__( if use_4over6: if pow_2_scales: raise ValueError("4over6 is only supported for NVFP4 (non-pow2) mode.") - if quant_tile_shape != (1, 16): - raise ValueError("4over6 reference quantization only supports 1x16 tiles.") + if quant_tile_shape not in ((1, 16), (16, 16)): + raise ValueError("4over6 reference quantization only supports 1x16 or 16x16 tiles.") if with_rht: raise ValueError("4over6 reference quantization does not support RHT.") super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -464,6 +464,7 @@ def _quantize_blockwise_4over6_reference( global_encode_scale: torch.Tensor, global_decode_scale: torch.Tensor, row_scaled_nvfp4: bool, + tile_len_y: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize NVFP4 with 4over6 candidate selection. @@ -537,7 +538,12 @@ def _quantize_blockwise_4over6_reference( val_map6 = val_map6 / denom diff_map6 = val_map6 - x_float[:, :, idx] err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) - pick_map4 = err_map4 < err_map6 + if tile_len_y == 1: + pick_map4 = err_map4 < err_map6 + else: + err_map4_blocks = err_map4.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) + err_map6_blocks = err_map6.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) + pick_map4 = (err_map4_blocks < err_map6_blocks).repeat_interleave(tile_len_y, dim=0) qx = torch.where( pick_map4.expand(-1, -1, tile_len_x // 2), qx_map4.view(m, num_blocks, tile_len_x // 2), @@ -601,8 +607,6 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: - if use_4over6 and using_2d_quantization: - raise ValueError("4over6 reference quantization does not support 2D quantization.") if row_scaled_nvfp4: global_amax = global_amax.to(torch.float32).view(m, 1, 1) @@ -636,6 +640,7 @@ def _quantize_blockwise_reference( global_encode_scale, global_decode_scale, row_scaled_nvfp4, + tile_len_y, ) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) From 708c1ec052f2cb8b4fc83dc1a42940c04711a3bf Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 20:07:11 -0700 Subject: [PATCH 27/57] Refactor 2d Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 108 +++++++----------- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 36 +++--- ...quantize_transpose_vector_blockwise_fp4.cu | 70 ++++++------ 3 files changed, 93 insertions(+), 121 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 42fe61d9b9..14aab1c208 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -264,6 +264,19 @@ __device__ __forceinline__ void quantize_4over6_16x( } } +__device__ __forceinline__ bool pick_4over6_map4(const float err_map4, const float err_map6) { + return err_map4 < err_map6; +} + +template +__device__ __forceinline__ nvfp4_scale_t selected_4over6_scale( + const bool pick_map4, const QuantizationScales4Over6 &scaling_factors) { + if (pick_map4) { + return scaling_factors.S_dec_b_fp8_map4; + } + return scaling_factors.S_dec_b_fp8_map6; +} + template __device__ __forceinline__ void quantize_4over6_16x( const float (&first_half)[8], const float (&second_half)[8], @@ -278,12 +291,12 @@ __device__ __forceinline__ void quantize_4over6_16x( global_amax, err_map4, err_map6, rOut_map4, rOut_map6); - if (err_map4 < err_map6) { - S_dec_b_fp8 = scaling_factors.S_dec_b_fp8_map4; + const bool pick_map4 = pick_4over6_map4(err_map4, err_map6); + S_dec_b_fp8 = selected_4over6_scale(pick_map4, scaling_factors); + if (pick_map4) { rOut[0] = rOut_map4[0]; rOut[1] = rOut_map4[1]; } else { - S_dec_b_fp8 = scaling_factors.S_dec_b_fp8_map6; rOut[0] = rOut_map6[0]; rOut[1] = rOut_map6[1]; } @@ -315,47 +328,16 @@ __device__ __forceinline__ void quantize_4over6_candidates_16x( candidates.rOut_map6); } -template -__device__ __forceinline__ void reduce_4over6_2d_block_selection( - const size_t block_in_tile_y, const size_t reduce_thread_idx, const float global_encode_scale, - const float global_decode_scale, - float (&block_amax_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1], - float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], - float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], - uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], - nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]) { - if (reduce_thread_idx < BLOCKS_PER_TILE_X) { - const size_t reduce_block_x = reduce_thread_idx; - float block_err_map4 = 0.0f; - float block_err_map6 = 0.0f; -#pragma unroll - for (int i = 0; i < BLOCK_DIM; ++i) { - block_err_map4 += err_map4_matrix[block_in_tile_y][reduce_block_x][i]; - block_err_map6 += err_map6_matrix[block_in_tile_y][reduce_block_x][i]; - } - - const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( - block_amax_matrix[block_in_tile_y][reduce_block_x], global_encode_scale, - global_decode_scale); - if (block_err_map4 < block_err_map6) { - pick_map4_matrix[block_in_tile_y][reduce_block_x] = 1; - selected_scale_matrix[block_in_tile_y][reduce_block_x] = scaling_factors.S_dec_b_fp8_map4; - } else { - pick_map4_matrix[block_in_tile_y][reduce_block_x] = 0; - selected_scale_matrix[block_in_tile_y][reduce_block_x] = scaling_factors.S_dec_b_fp8_map6; - } - } -} - -template -__device__ __forceinline__ void record_and_reduce_4over6_2d_block_selection( - const float block_amax, const float global_encode_scale, const float global_decode_scale, +template +__device__ __forceinline__ bool record_and_select_4over6_2d_block( + const QuantizationScales4Over6 &scaling_factors, const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], - const QuantizationCandidates4Over6 &candidates) { + nvfp4_scale_t &S_dec_b_fp8, const QuantizationCandidates4Over6 &candidates) { err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; err_map6_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map6; __syncthreads(); @@ -369,46 +351,43 @@ __device__ __forceinline__ void record_and_reduce_4over6_2d_block_selection( block_err_map6 += err_map6_matrix[block_in_tile_y][block_in_tile_x][i]; } - const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( - block_amax, global_encode_scale, global_decode_scale); - if (block_err_map4 < block_err_map6) { + const bool pick_map4 = pick_4over6_map4(block_err_map4, block_err_map6); + if (pick_map4) { pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 1; - selected_scale_matrix[block_in_tile_y][block_in_tile_x] = scaling_factors.S_dec_b_fp8_map4; } else { pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 0; - selected_scale_matrix[block_in_tile_y][block_in_tile_x] = scaling_factors.S_dec_b_fp8_map6; } + selected_scale_matrix[block_in_tile_y][block_in_tile_x] = + selected_4over6_scale(pick_map4, scaling_factors); } __syncthreads(); + S_dec_b_fp8 = selected_scale_matrix[block_in_tile_y][block_in_tile_x]; + return pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; } template -__device__ __forceinline__ void quantize_4over6_2d_block_candidate( +__device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( const float (&x)[16], const float block_amax, const float global_encode_scale, const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, - const size_t block_in_tile_x, const size_t participant_idx, const size_t reduce_thread_idx, - float (&block_amax_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1], + const size_t block_in_tile_x, const size_t participant_idx, float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], - QuantizationCandidates4Over6 &candidates) { + nvfp4_scale_t &S_dec_b_fp8, QuantizationCandidates4Over6 &candidates) { const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( block_amax, global_encode_scale, global_decode_scale); quantize_4over6_candidates_16x(x, scaling_factors, global_amax, candidates); - err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; - err_map6_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map6; - __syncthreads(); - - reduce_4over6_2d_block_selection( - block_in_tile_y, reduce_thread_idx, global_encode_scale, global_decode_scale, - block_amax_matrix, err_map4_matrix, err_map6_matrix, pick_map4_matrix, selected_scale_matrix); - __syncthreads(); + const bool pick_map4 = + record_and_select_4over6_2d_block( + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, err_map4_matrix, + err_map6_matrix, pick_map4_matrix, selected_scale_matrix, S_dec_b_fp8, candidates); + return pick_map4; } -__device__ __forceinline__ uint32_t *selected_4over6_packed( - const bool pick_map4, QuantizationCandidates4Over6 &candidates) { +__device__ __forceinline__ const uint32_t *selected_4over6_packed( + const bool pick_map4, const QuantizationCandidates4Over6 &candidates) { if (pick_map4) { return candidates.rOut_map4; } @@ -417,9 +396,9 @@ __device__ __forceinline__ uint32_t *selected_4over6_packed( template __device__ __forceinline__ void store_4over6_colwise_packed_16x( - const bool pick_map4, QuantizationCandidates4Over6 &candidates, const int thread_lane, + const bool pick_map4, const QuantizationCandidates4Over6 &candidates, const int thread_lane, output_type *out_t_data_sh, const size_t shmem_offset_base_colwise_out_t) { - uint32_t *regs_4x = selected_4over6_packed(pick_map4, candidates); + const uint32_t *regs_4x = selected_4over6_packed(pick_map4, candidates); const int group = thread_lane / 16; uint32_t val[2]; switch (group) { @@ -440,10 +419,10 @@ __device__ __forceinline__ void store_4over6_colwise_packed_16x( template __device__ __forceinline__ void store_4over6_rowwise_packed_16x( - const bool pick_map4, QuantizationCandidates4Over6 &candidates, const int bank_group, + const bool pick_map4, const QuantizationCandidates4Over6 &candidates, const int bank_group, const size_t thread_offset_X_rowwise, const size_t shmem_offset_base_rowwise_out, output_type *out_data_sh) { - uint32_t *packed = selected_4over6_packed(pick_map4, candidates); + const uint32_t *packed = selected_4over6_packed(pick_map4, candidates); #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; @@ -464,8 +443,9 @@ __device__ __forceinline__ void store_4over6_packed_16x(const uint32_t (&packed) template __device__ __forceinline__ void store_selected_4over6_packed_16x( - const bool pick_map4, QuantizationCandidates4Over6 &candidates, output_vec_type &output_vec) { - uint32_t *packed = selected_4over6_packed(pick_map4, candidates); + const bool pick_map4, const QuantizationCandidates4Over6 &candidates, + output_vec_type &output_vec) { + const uint32_t *packed = selected_4over6_packed(pick_map4, candidates); *reinterpret_cast(&output_vec.data.elt[0]) = packed[0]; *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; } diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index d25155f4de..75886fe54b 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1097,21 +1097,20 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - QuantizationCandidates4Over6 candidates; const size_t block_col = threadIdx.x % BLOCK_DIM; - quantize_4over6_2d_block_candidate( - x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, - block_in_tile_y, block_in_tile_x, block_col, threadIdx.x, block_amax_matrix, - err_map4_matrix, err_map6_matrix, pick_map4_matrix, selected_scale_matrix, - candidates); - - const nvfp4_scale_t S_dec_b_fp8 = selected_scale_matrix[block_in_tile_y][block_in_tile_x]; + QuantizationCandidates4Over6 candidates; + nvfp4_scale_t S_dec_b_fp8; + const bool pick_map4 = + quantize_and_select_4over6_2d_block_16x( + x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, + block_in_tile_y, block_in_tile_x, block_col, err_map4_matrix, err_map6_matrix, + pick_map4_matrix, selected_scale_matrix, S_dec_b_fp8, candidates); + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - const bool pick_map4 = pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; store_4over6_colwise_packed_16x(pick_map4, candidates, thread_lane, out_t_data_sh, shmem_offset_base_colwise_out_t); } else { @@ -1263,14 +1262,14 @@ __global__ void __launch_bounds__(THREADS_NUM) if constexpr (USE_4OVER6) { QuantizationCandidates4Over6 candidates; - quantize_4over6_2d_block_candidate( - in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, - block_in_tile_y, block_in_tile_x, tid_Y_rowwise, threadIdx.x, block_amax_matrix, - err_map4_matrix, err_map6_matrix, pick_map4_matrix, selected_scale_matrix, - candidates); - - const nvfp4_scale_t S_dec_b_fp8 = selected_scale_matrix[block_in_tile_y][block_in_tile_x]; + nvfp4_scale_t S_dec_b_fp8; + const bool pick_map4 = + quantize_and_select_4over6_2d_block_16x( + in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, + block_in_tile_y, block_in_tile_x, tid_Y_rowwise, err_map4_matrix, err_map6_matrix, + pick_map4_matrix, selected_scale_matrix, S_dec_b_fp8, candidates); + const size_t scales_offset_Y = scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; const size_t scales_offset_X = scales_offset_X_rowwise; @@ -1282,7 +1281,6 @@ __global__ void __launch_bounds__(THREADS_NUM) scales_ptr[scale_idx_global] = S_dec_b_fp8; } - const bool pick_map4 = pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; store_4over6_rowwise_packed_16x( pick_map4, candidates, bank_group, thread_offset_X_rowwise, shmem_offset_base_rowwise_out, out_data_sh); diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 4db1c1d63f..dd7cb24d45 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -34,6 +34,7 @@ using std::uint32_t; using std::uint8_t; using transformer_engine::detail::TypeExtrema; +namespace nvfp4_core = transformer_engine::dispatch::nvfp4::core; using transformer_engine::dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; // clang-format off @@ -522,9 +523,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float encode_scale; OVec output_vec; if constexpr (kUse4Over6) { - const auto scaling_factors = transformer_engine::dispatch::nvfp4::core:: - compute_4over6_fp4_encode_quantization_scaling_factors(amax, row_global_encode_scale, - row_global_decode_scale); + const auto scaling_factors = + nvfp4_core::compute_4over6_fp4_encode_quantization_scaling_factors( + amax, row_global_encode_scale, row_global_decode_scale); float row_global_amax; if constexpr (kRowScaledNVFP4) { if (row_idx < num_rows) { @@ -547,25 +548,21 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t block_in_tile_x = tid_in_warp_x; const size_t participant_idx = data_row_idx % kFP4BlockScalingSize; - transformer_engine::dispatch::nvfp4::core::QuantizationCandidates4Over6 candidates; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_candidates_16x< - kUseFastMath>(smem_vec, scaling_factors, row_global_amax, candidates); - transformer_engine::dispatch::nvfp4::core::record_and_reduce_4over6_2d_block_selection< - kFP4BlockScalingSize, k2DBlockAmaxDim, k2DBlockAmaxDim>( - amax, row_global_encode_scale, row_global_decode_scale, block_in_tile_y, - block_in_tile_x, participant_idx, err_map4_smem, err_map6_smem, pick_map4_smem, - selected_scale_smem, candidates); - - const bool pick_map4 = pick_map4_smem[block_in_tile_y][block_in_tile_x] == 1; - scale_inv = selected_scale_smem[block_in_tile_y][block_in_tile_x]; - transformer_engine::dispatch::nvfp4::core::store_selected_4over6_packed_16x( - pick_map4, candidates, output_vec); + nvfp4_core::QuantizationCandidates4Over6 candidates; + nvfp4_core::quantize_4over6_vec2_array_candidates_16x( + smem_vec, scaling_factors, row_global_amax, candidates); + const bool pick_map4 = + nvfp4_core::record_and_select_4over6_2d_block( + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, err_map4_smem, + err_map6_smem, pick_map4_smem, selected_scale_smem, scale_inv, candidates); + + nvfp4_core::store_selected_4over6_packed_16x(pick_map4, candidates, output_vec); } else { uint32_t output_vec_4over6[2]; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec2_array_16x( + nvfp4_core::quantize_4over6_vec2_array_16x( smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); - transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, - output_vec); + nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } } else { scale_inv = ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); @@ -698,9 +695,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float encode_scale; OVec output_vec; if constexpr (kUse4Over6) { - const auto scaling_factors = transformer_engine::dispatch::nvfp4::core:: - compute_4over6_fp4_encode_quantization_scaling_factors(amax, global_encode_scale, - global_decode_scale); + const auto scaling_factors = + nvfp4_core::compute_4over6_fp4_encode_quantization_scaling_factors( + amax, global_encode_scale, global_decode_scale); if constexpr (kIs2DBlockScaling) { const int warp_idx = threadIdx.x / kThreadsPerWarp; @@ -715,25 +712,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t block_in_tile_x = data_col_idx / kFP4BlockScalingSize; const size_t participant_idx = data_col_idx % kFP4BlockScalingSize; - transformer_engine::dispatch::nvfp4::core::QuantizationCandidates4Over6 candidates; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_candidates_16x< - kUseFastMath>(smem_vec, smem_idx, scaling_factors, global_amax[0], candidates); - transformer_engine::dispatch::nvfp4::core::record_and_reduce_4over6_2d_block_selection< - kFP4BlockScalingSize, k2DBlockAmaxDim, k2DBlockAmaxDim>( - amax, global_encode_scale, global_decode_scale, block_in_tile_y, block_in_tile_x, - participant_idx, err_map4_smem, err_map6_smem, pick_map4_smem, selected_scale_smem, - candidates); - - const bool pick_map4 = pick_map4_smem[block_in_tile_y][block_in_tile_x] == 1; - scale_inv = selected_scale_smem[block_in_tile_y][block_in_tile_x]; - transformer_engine::dispatch::nvfp4::core::store_selected_4over6_packed_16x( - pick_map4, candidates, output_vec); + nvfp4_core::QuantizationCandidates4Over6 candidates; + nvfp4_core::quantize_4over6_vec_index_candidates_16x( + smem_vec, smem_idx, scaling_factors, global_amax[0], candidates); + const bool pick_map4 = + nvfp4_core::record_and_select_4over6_2d_block( + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, + err_map4_smem, err_map6_smem, pick_map4_smem, selected_scale_smem, scale_inv, + candidates); + + nvfp4_core::store_selected_4over6_packed_16x(pick_map4, candidates, output_vec); } else { uint32_t output_vec_4over6[2]; - transformer_engine::dispatch::nvfp4::core::quantize_4over6_vec_index_16x( + nvfp4_core::quantize_4over6_vec_index_16x( smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); - transformer_engine::dispatch::nvfp4::core::store_4over6_packed_16x(output_vec_4over6, - output_vec); + nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } } else { scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); From 4d31f1897a53532119a9992f01daf28a57fac8d3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 20:15:46 -0700 Subject: [PATCH 28/57] Clean up anti pattern Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 26 ++++++++++++++++--------- tests/pytorch/test_fusible_ops.py | 25 +++++++++++++----------- tests/pytorch/utils.py | 8 ++++---- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 5f54267a8f..5e6f36e8b4 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -175,7 +175,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if "nvfp4" in recipe_name: + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -188,14 +188,14 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - if module_type == "ops_linear" and "nvfp4_row_scaled" in recipe_name: + if module_type == "ops_linear" and recipe_name == "nvfp4_row_scaled": pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") - if module_type == "grouped_linear" and "nvfp4_4over6" in recipe_name: + if module_type == "grouped_linear" and recipe_name == "nvfp4_4over6": pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: - if "nvfp4_row_scaled" in recipe_name: + if recipe_name == "nvfp4_row_scaled": return make_recipe(recipe_name, backward_override="dequantized") return make_recipe(recipe_name) @@ -215,7 +215,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if "nvfp4" in recipe_name and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -240,7 +242,9 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if "nvfp4" in recipe_name and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -259,9 +263,13 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if "nvfp4" in recipe_name and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and any( + m % 16 != 0 for m in non_empty_splits + ): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if "nvfp4" in recipe_name and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and any( + m % 64 != 0 for m in non_empty_splits + ): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." @@ -1744,7 +1752,7 @@ def test_backward_override_memory_peak_report( modes = ( ("high_precision", "dequantized") - if "nvfp4_row_scaled" in recipe_name + if recipe_name == "nvfp4_row_scaled" else (None, "high_precision", "dequantized") ) mode_results: dict[str, dict[str, float] | str] = {} diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9d060611d2..e57ee9d098 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -107,7 +107,7 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if "nvfp4" in quantization and not nvfp4_available: + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and not nvfp4_available: pytest.skip(reason_for_no_nvfp4) # Check dims @@ -120,13 +120,16 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif "nvfp4" in quantization: + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: - if "nvfp4" in quantization and dtype != torch.bfloat16: + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and dtype != torch.bfloat16 + ): pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -181,14 +184,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif "nvfp4" in quantization: + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, with_2d_quantization=False, stochastic_rounding=False, with_random_sign_mask=False, - use_4over6="4over6" in quantization, + use_4over6=quantization == "nvfp4_4over6", )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -1514,7 +1517,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif quantization is not None and "nvfp4" in quantization: + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1885,7 +1888,7 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute and "nvfp4" in quantization: + if quantized_compute and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = dtype_tols(tex.DType.kFloat4E2M1) elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) @@ -2078,7 +2081,7 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if quantization and "4over6" in quantization: + if quantization == "nvfp4_4over6": pytest.skip("NVFP4 4over6 grouped quantization is not supported") if single_grouped_bias and not bias: @@ -3612,11 +3615,11 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if with_quantization and "4over6" in quantization: + if quantization == "nvfp4_4over6": pytest.skip("NVFP4 4over6 grouped quantization is not supported") if ( with_quantization - and "nvfp4" in quantization + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and activation == "scaled_clamped_qgeglu" and bias ): @@ -3841,7 +3844,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if with_quantization and "nvfp4" in quantization: + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = {"rtol": 0.25, "atol": 0.5} # Check values diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index c58d3e751e..f8113332d9 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if "nvfp4" in name: + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -145,13 +145,13 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) - if "nvfp4" in name: - use_4over6 = "4over6" in name + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + use_4over6 = name == "nvfp4_4over6" kwargs = { "disable_rht": True, "disable_stochastic_rounding": True, "disable_2d_quantization": not use_4over6, - "row_scaled_activation": "row_scaled" in name, + "row_scaled_activation": name == "nvfp4_row_scaled", "nvfp4_4over6": "all" if use_4over6 else None, } kwargs.update(recipe_kwargs) From dfc15f23b19dc98acf8855cf3c01aa8d24a40fa8 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 20:35:51 -0700 Subject: [PATCH 29/57] Enforce 4over6 consistency Signed-off-by: Ziang Li --- tests/cpp/operator/test_dequantize_nvfp4.cu | 8 ++++++-- .../common/cast/dispatch/quantize.cuh | 17 +++++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index e1289b7d60..714574facd 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -120,7 +120,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Quantize if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized.data(), 0); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_4over6(use_4over6); + nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); @@ -181,7 +183,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, } if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized_compact.data(), 0); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_4over6(use_4over6); + nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); cudaDeviceSynchronize(); } diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 597d034844..a28aac98ca 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -101,11 +101,11 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; - const bool use_4over6 = quant_config_cpp.nvfp4_4over6 || output_tensor->nvfp4_4over6; + NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, + "Tensor and quantization config have inconsistent options for NVFP4 4over6."); + const bool use_4over6 = quant_config_cpp.nvfp4_4over6; NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); - quant_config_cpp.nvfp4_4over6 = use_4over6; - output_tensor->nvfp4_4over6 = use_4over6; if (row_scaled_nvfp4) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -256,11 +256,11 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); - const bool use_4over6 = quant_config_cpp.nvfp4_4over6 || output_tensor->nvfp4_4over6; + NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, + "Tensor and quantization config have inconsistent options for NVFP4 4over6."); + const bool use_4over6 = quant_config_cpp.nvfp4_4over6; NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); - quant_config_cpp.nvfp4_4over6 = use_4over6; - output_tensor->nvfp4_4over6 = use_4over6; NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && @@ -386,10 +386,11 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); - bool use_4over6 = quant_config_cpp.nvfp4_4over6; for (const auto *output_tensor : output_tensors) { - use_4over6 = use_4over6 || output_tensor->nvfp4_4over6; + NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, + "Tensor and quantization config have inconsistent options for NVFP4 4over6."); } + const bool use_4over6 = quant_config_cpp.nvfp4_4over6; NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "2D quantization is not supported for group quantize."); NVTE_CHECK(!use_4over6, "NVFP4 4over6 quantization is not supported for group quantize."); From 945367098b4b6e8ba3ce5ad9c26d6c6bf1e32680 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 20:41:25 -0700 Subject: [PATCH 30/57] Update comments Signed-off-by: Ziang Li --- .../include/transformer_engine/transformer_engine.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 1b78dcf826..67756d477b 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -83,7 +83,13 @@ enum NVTETensorParam { * its values are populated during quantization. */ kNVTERowScaledNVFP4 = 8, - kNVTENVFP44Over6 = 9, /*!< Whether an NVFP4 tensor uses 4over6 scaling */ + /*! Whether an NVFP4 tensor is encoded with 4over6 semantics. + * + * This is part of the tensor data contract: 4over6 tensors use 256 as + * their global E4M3 scale bound, so downstream dequantization and GEMM + * scale consumers must decode them differently from default NVFP4 tensors. + */ + kNVTENVFP44Over6 = 9, kNVTENumTensorParams }; @@ -385,7 +391,9 @@ enum NVTEQuantizationConfigAttribute { /*! Whether to use NVFP4 4over6 block scale selection. * * 4over6 evaluates map-to-4 and map-to-6 candidates for each 1x16 block, - * stores the lower-MSE candidate, and uses a 256 global E4M3 scale bound. + * stores the lower-MSE candidate, and emits tensor data that uses a 256 + * global E4M3 scale bound. The output tensor's kNVTENVFP44Over6 metadata + * must match this option. */ kNVTEQuantizationConfigNVFP44Over6 = 8, kNVTEQuantizationConfigNumAttributes From 6d871da2d1011831f42cb93f499ca28a4f85ff0b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 20:55:44 -0700 Subject: [PATCH 31/57] Update docs Signed-off-by: Ziang Li --- docs/envvars.rst | 4 ++-- transformer_engine/common/recipe/__init__.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index d4302ea91d..82ada36c19 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -261,7 +261,7 @@ Kernel Configuration :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Enable fast math optimizations in runtime-compiled (NVRTC) kernels. This trades numerical accuracy for performance. These optimizations are experimental and inconsistently implemented. + :Description: Enable fast math optimizations in supported kernels, including runtime-compiled (NVRTC) kernels and NVFP4 4over6 quantization. This trades numerical accuracy for performance. These optimizations are experimental and inconsistently implemented. .. envvar:: NVTE_DISABLE_NVRTC @@ -291,7 +291,7 @@ Kernel Configuration :Type: ``str`` (``weights``, ``activations``, or ``all``) :Default: unset - :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled. + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled; activation and backward scopes therefore require ``NVTE_NVFP4_DISABLE_RHT=1`` and ``NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1``. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 3b1fdb4e3f..82cc42fedd 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -530,7 +530,9 @@ class NVFP4BlockScaling(Recipe): global E4M3 scale bound is 256 in this mode instead of 448. The ``activations`` scope applies to every non-weight tensor role. Random Hadamard transforms and stochastic rounding are not yet - supported on tensors that use 4over6. + supported on tensors that use 4over6; activation and backward + scopes therefore require ``disable_rht=True`` and + ``disable_stochastic_rounding=True``. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, From f8338e835bae747746a7e706ca16e8fb26736cfb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 21:36:11 -0700 Subject: [PATCH 32/57] Fix test Signed-off-by: Ziang Li --- tests/pytorch/test_fusible_ops.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e57ee9d098..e201c673aa 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -145,6 +145,7 @@ def make_reference_and_test_tensors( test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", test_is_quantized: bool = False, + nvfp4_weight: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -155,6 +156,8 @@ def make_reference_and_test_tensors( If a quantization scheme is provided, the tensor values are quantized so that they are representable. + NVFP4 4over6 activations use 1D quantization, while linear weights + use the recipe's 2D weight quantization path. """ @@ -188,7 +191,7 @@ def make_reference_and_test_tensors( test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, - with_2d_quantization=False, + with_2d_quantization=quantization == "nvfp4_4over6" and nvfp4_weight, stochastic_rounding=False, with_random_sign_mask=False, use_4over6=quantization == "nvfp4_4over6", @@ -508,6 +511,7 @@ def test_dtype_cast( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) # Construct operation @@ -915,6 +919,7 @@ def _test_basic_linear( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1087,6 +1092,7 @@ def test_linear( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -2119,6 +2125,7 @@ def test_grouped_linear( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, requires_grad=weight_requires_grad, ) b_ref, b_test = None, None @@ -2629,6 +2636,7 @@ def test_forward_linear_bias_activation( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -2734,6 +2742,7 @@ def test_forward_linear_bias_add( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -2847,6 +2856,7 @@ def test_forward_linear_scale_add( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) x2_ref, x2_test = make_reference_and_test_tensors( out_shape, @@ -3129,6 +3139,7 @@ def test_backward_linear_add( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -3232,6 +3243,7 @@ def test_backward_linear_scale( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -3461,12 +3473,14 @@ def test_layernorm_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) w2_ref, w2_test = make_reference_and_test_tensors( (hidden_size, ffn_hidden_size // 2), quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b1_ref, b1_test, b2_ref, b2_test = None, None, None, None if bias: @@ -3661,6 +3675,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( (hidden_size, hidden_size), @@ -3669,6 +3684,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) fc1_b_ref, fc1_b_test = None, None fc2_b_ref, fc2_b_test = None, None From c9bc9211e2a87624c0b6d4001d820dd88b7a1b87 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 22:06:55 -0700 Subject: [PATCH 33/57] Drop test_fusible_ops Signed-off-by: Ziang Li --- tests/pytorch/test_fusible_ops.py | 48 ++++++------------------------- 1 file changed, 9 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e201c673aa..7691582f97 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -77,7 +77,6 @@ _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") - _quantization_list.append("nvfp4_4over6") @pytest.fixture(autouse=True, scope="class") @@ -107,7 +106,7 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and not nvfp4_available: + if quantization == "nvfp4" and not nvfp4_available: pytest.skip(reason_for_no_nvfp4) # Check dims @@ -120,16 +119,13 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + elif quantization == "nvfp4": if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: - if ( - quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") - and dtype != torch.bfloat16 - ): + if quantization == "nvfp4" and dtype != torch.bfloat16: pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -145,7 +141,6 @@ def make_reference_and_test_tensors( test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", test_is_quantized: bool = False, - nvfp4_weight: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -156,8 +151,6 @@ def make_reference_and_test_tensors( If a quantization scheme is provided, the tensor values are quantized so that they are representable. - NVFP4 4over6 activations use 1D quantization, while linear weights - use the recipe's 2D weight quantization path. """ @@ -187,14 +180,13 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + elif quantization == "nvfp4": test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, - with_2d_quantization=quantization == "nvfp4_4over6" and nvfp4_weight, + with_2d_quantization=False, stochastic_rounding=False, with_random_sign_mask=False, - use_4over6=quantization == "nvfp4_4over6", )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -511,7 +503,6 @@ def test_dtype_cast( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) # Construct operation @@ -919,7 +910,6 @@ def _test_basic_linear( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1092,7 +1082,6 @@ def test_linear( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -1523,7 +1512,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + elif quantization == "nvfp4": tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1894,7 +1883,7 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if quantized_compute and quantization == "nvfp4": tols = dtype_tols(tex.DType.kFloat4E2M1) elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) @@ -2087,8 +2076,6 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if quantization == "nvfp4_4over6": - pytest.skip("NVFP4 4over6 grouped quantization is not supported") if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") @@ -2125,7 +2112,6 @@ def test_grouped_linear( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, requires_grad=weight_requires_grad, ) b_ref, b_test = None, None @@ -2636,7 +2622,6 @@ def test_forward_linear_bias_activation( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -2742,7 +2727,6 @@ def test_forward_linear_bias_add( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -2856,7 +2840,6 @@ def test_forward_linear_scale_add( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) x2_ref, x2_test = make_reference_and_test_tensors( out_shape, @@ -3139,7 +3122,6 @@ def test_backward_linear_add( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -3243,7 +3225,6 @@ def test_backward_linear_scale( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -3473,14 +3454,12 @@ def test_layernorm_mlp( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) w2_ref, w2_test = make_reference_and_test_tensors( (hidden_size, ffn_hidden_size // 2), quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) b1_ref, b1_test, b2_ref, b2_test = None, None, None, None if bias: @@ -3629,14 +3608,7 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if quantization == "nvfp4_4over6": - pytest.skip("NVFP4 4over6 grouped quantization is not supported") - if ( - with_quantization - and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") - and activation == "scaled_clamped_qgeglu" - and bias - ): + if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") @@ -3675,7 +3647,6 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( (hidden_size, hidden_size), @@ -3684,7 +3655,6 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, ) fc1_b_ref, fc1_b_test = None, None fc2_b_ref, fc2_b_test = None, None @@ -3860,7 +3830,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if quantization == "nvfp4": tols = {"rtol": 0.25, "atol": 0.5} # Check values From 00ba6949f4a26240a6df20e1d76b51244a146f8a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 22:32:50 -0700 Subject: [PATCH 34/57] Revert "Drop test_fusible_ops" This reverts commit 69f9ccc36a9c459f50c2f00b6cd6a62c5e1bdf13. Signed-off-by: Ziang Li --- tests/pytorch/test_fusible_ops.py | 48 +++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7691582f97..e201c673aa 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -77,6 +77,7 @@ _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") + _quantization_list.append("nvfp4_4over6") @pytest.fixture(autouse=True, scope="class") @@ -106,7 +107,7 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if quantization == "nvfp4" and not nvfp4_available: + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and not nvfp4_available: pytest.skip(reason_for_no_nvfp4) # Check dims @@ -119,13 +120,16 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: - if quantization == "nvfp4" and dtype != torch.bfloat16: + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and dtype != torch.bfloat16 + ): pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -141,6 +145,7 @@ def make_reference_and_test_tensors( test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", test_is_quantized: bool = False, + nvfp4_weight: bool = False, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -151,6 +156,8 @@ def make_reference_and_test_tensors( If a quantization scheme is provided, the tensor values are quantized so that they are representable. + NVFP4 4over6 activations use 1D quantization, while linear weights + use the recipe's 2D weight quantization path. """ @@ -180,13 +187,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, - with_2d_quantization=False, + with_2d_quantization=quantization == "nvfp4_4over6" and nvfp4_weight, stochastic_rounding=False, with_random_sign_mask=False, + use_4over6=quantization == "nvfp4_4over6", )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -503,6 +511,7 @@ def test_dtype_cast( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) # Construct operation @@ -910,6 +919,7 @@ def _test_basic_linear( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1082,6 +1092,7 @@ def test_linear( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -1512,7 +1523,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1883,7 +1894,7 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute and quantization == "nvfp4": + if quantized_compute and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = dtype_tols(tex.DType.kFloat4E2M1) elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) @@ -2076,6 +2087,8 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") @@ -2112,6 +2125,7 @@ def test_grouped_linear( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, requires_grad=weight_requires_grad, ) b_ref, b_test = None, None @@ -2622,6 +2636,7 @@ def test_forward_linear_bias_activation( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -2727,6 +2742,7 @@ def test_forward_linear_bias_add( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b_ref, b_test = None, None if bias: @@ -2840,6 +2856,7 @@ def test_forward_linear_scale_add( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) x2_ref, x2_test = make_reference_and_test_tensors( out_shape, @@ -3122,6 +3139,7 @@ def test_backward_linear_add( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -3225,6 +3243,7 @@ def test_backward_linear_scale( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -3454,12 +3473,14 @@ def test_layernorm_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) w2_ref, w2_test = make_reference_and_test_tensors( (hidden_size, ffn_hidden_size // 2), quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) b1_ref, b1_test, b2_ref, b2_test = None, None, None, None if bias: @@ -3608,7 +3629,14 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if ( + with_quantization + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and activation == "scaled_clamped_qgeglu" + and bias + ): # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") @@ -3647,6 +3675,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( (hidden_size, hidden_size), @@ -3655,6 +3684,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + nvfp4_weight=True, ) fc1_b_ref, fc1_b_test = None, None fc2_b_ref, fc2_b_test = None, None @@ -3830,7 +3860,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization == "nvfp4": + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = {"rtol": 0.25, "atol": 0.5} # Check values From 3252d4e2b8f41f8ca8ac4867b30bdb707ce7db60 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 22:45:40 -0700 Subject: [PATCH 35/57] Refactor test_fusible_ops Signed-off-by: Ziang Li --- tests/pytorch/test_fusible_ops.py | 38 +++++++++++++++++-------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e201c673aa..179cc417d5 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -39,6 +39,7 @@ Float8Quantizer, MXFP8Quantizer, NVFP4Quantizer, + QuantizerRole, is_bf16_available, ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor @@ -145,7 +146,7 @@ def make_reference_and_test_tensors( test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", test_is_quantized: bool = False, - nvfp4_weight: bool = False, + quantizer_role: Optional[QuantizerRole] = None, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -156,8 +157,8 @@ def make_reference_and_test_tensors( If a quantization scheme is provided, the tensor values are quantized so that they are representable. - NVFP4 4over6 activations use 1D quantization, while linear weights - use the recipe's 2D weight quantization path. + NVFP4 4over6 follows recipe role dispatch: activation-like tensors + use 1D quantization and weight tensors use the 2D weight path. """ @@ -188,10 +189,13 @@ def make_reference_and_test_tensors( elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + with_2d_quantization = False + if quantization == "nvfp4_4over6" and quantizer_role is not None: + with_2d_quantization = quantizer_role.tensor_type == "weight" test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, - with_2d_quantization=quantization == "nvfp4_4over6" and nvfp4_weight, + with_2d_quantization=with_2d_quantization, stochastic_rounding=False, with_random_sign_mask=False, use_4over6=quantization == "nvfp4_4over6", @@ -511,7 +515,7 @@ def test_dtype_cast( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) # Construct operation @@ -919,7 +923,7 @@ def _test_basic_linear( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1092,7 +1096,7 @@ def test_linear( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -2125,7 +2129,7 @@ def test_grouped_linear( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), requires_grad=weight_requires_grad, ) b_ref, b_test = None, None @@ -2636,7 +2640,7 @@ def test_forward_linear_bias_activation( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -2742,7 +2746,7 @@ def test_forward_linear_bias_add( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -2856,7 +2860,7 @@ def test_forward_linear_scale_add( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) x2_ref, x2_test = make_reference_and_test_tensors( out_shape, @@ -3139,7 +3143,7 @@ def test_backward_linear_add( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -3243,7 +3247,7 @@ def test_backward_linear_scale( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -3473,14 +3477,14 @@ def test_layernorm_mlp( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) w2_ref, w2_test = make_reference_and_test_tensors( (hidden_size, ffn_hidden_size // 2), quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b1_ref, b1_test, b2_ref, b2_test = None, None, None, None if bias: @@ -3675,7 +3679,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( (hidden_size, hidden_size), @@ -3684,7 +3688,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, - nvfp4_weight=True, + quantizer_role=QuantizerRole(tensor_type="weight"), ) fc1_b_ref, fc1_b_test = None, None fc2_b_ref, fc2_b_test = None, None From 3f33c1d50fa6671c32164fd68048ff529632ab64 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 23:25:49 -0700 Subject: [PATCH 36/57] Refactor ref and extend cpp test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 211 ++++++++++++++++-- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 9 +- .../custom_recipes/quantization_ref_nvfp4.py | 5 +- 3 files changed, 197 insertions(+), 28 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 63b1e00462..36f26ca070 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -83,6 +83,24 @@ float compute_global_encode_scaling_factor_FP4(const float global_amax, const bo return global_encode_scale; } +struct NVFP4FourOverSixDecodeScales { + fp8e4m3 map4; + fp8e4m3 map6; +}; + +NVFP4FourOverSixDecodeScales compute_4over6_decoding_scaling_factors( + const float block_amax, const float global_encode_scale) { + constexpr float fp4_max = 6.0f; + constexpr float scale_expansion_factor = 1.5f; + const float base_sf_high_precision = block_amax / fp4_max * global_encode_scale; + const float sf_high_precision_map4 = base_sf_high_precision * scale_expansion_factor; + const float sf_high_precision_map6 = base_sf_high_precision; + return { + static_cast(sf_high_precision_map4), + static_cast(sf_high_precision_map6), + }; +} + // 1D Scaling: Original implementation with 1x16 blocks template void quantize_nvfp4_1d(float (*OP)(const float), @@ -130,9 +148,10 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t scale_idx = i * scales_stride + block_X; if (use_4over6) { - const float S_dec_b_map6 = block_amax * (S_enc * (1.0f / 6.0f)); - const fp8e4m3 S_dec_b_fp8_map6 = static_cast(S_dec_b_map6); - const fp8e4m3 S_dec_b_fp8_map4 = static_cast(S_dec_b_map6 * 1.5f); + const NVFP4FourOverSixDecodeScales S_dec_b_fp8 = + compute_4over6_decoding_scaling_factors(block_amax, S_enc); + const fp8e4m3 S_dec_b_fp8_map4 = S_dec_b_fp8.map4; + const fp8e4m3 S_dec_b_fp8_map6 = S_dec_b_fp8.map6; const float S_dec_b_fp32_map6 = static_cast(S_dec_b_fp8_map6); const float S_dec_b_fp32_map4 = static_cast(S_dec_b_fp8_map4); @@ -172,10 +191,12 @@ void quantize_nvfp4_1d(float (*OP)(const float), const fp4e2m1x2 casted_to_e2m1_pair_map6(scaled_elt_pair_map6); output_map6[cache_idx_x / 2] = casted_to_e2m1_pair_map6; const double2 truncated_pair_map6 = cvt_fp4x2_to_double2(casted_to_e2m1_pair_map6); - const double dequant_x_map6 = - truncated_pair_map6.x * S_dec_b_fp32_map6 * global_amax * mse_scale; - const double dequant_y_map6 = - truncated_pair_map6.y * S_dec_b_fp32_map6 * global_amax * mse_scale; + const float dequant_x_map6 = + static_cast(truncated_pair_map6.x) * S_dec_b_fp32_map6 * + global_amax * mse_scale; + const float dequant_y_map6 = + static_cast(truncated_pair_map6.y) * S_dec_b_fp32_map6 * + global_amax * mse_scale; err_map6 += (dequant_x_map6 - cached_x) * (dequant_x_map6 - cached_x); err_map6 += (dequant_y_map6 - cached_y) * (dequant_y_map6 - cached_y); @@ -186,10 +207,12 @@ void quantize_nvfp4_1d(float (*OP)(const float), const fp4e2m1x2 casted_to_e2m1_pair_map4(scaled_elt_pair_map4); output_map4[cache_idx_x / 2] = casted_to_e2m1_pair_map4; const double2 truncated_pair_map4 = cvt_fp4x2_to_double2(casted_to_e2m1_pair_map4); - const double dequant_x_map4 = - truncated_pair_map4.x * S_dec_b_fp32_map4 * global_amax * mse_scale; - const double dequant_y_map4 = - truncated_pair_map4.y * S_dec_b_fp32_map4 * global_amax * mse_scale; + const float dequant_x_map4 = + static_cast(truncated_pair_map4.x) * S_dec_b_fp32_map4 * + global_amax * mse_scale; + const float dequant_y_map4 = + static_cast(truncated_pair_map4.y) * S_dec_b_fp32_map4 * + global_amax * mse_scale; err_map4 += (dequant_x_map4 - cached_x) * (dequant_x_map4 - cached_x); err_map4 += (dequant_y_map4 - cached_y) * (dequant_y_map4 - cached_y); } @@ -286,9 +309,90 @@ void compute_2d_mathematical_scales(float (*OP)(const float), } // Compute E4M3 scaling factor for this 16x16 block - const float S_dec_b = block_amax / 6.0f; - const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); - math_scales[block_Y][block_X] = S_dec_b_fp8; + const float S_dec_b = block_amax / 6.0f * S_enc; + const fp8e4m3 S_dec_b_fp8_map6 = static_cast(S_dec_b); + if (use_4over6) { + const NVFP4FourOverSixDecodeScales S_dec_b_fp8 = + compute_4over6_decoding_scaling_factors(block_amax, S_enc); + const fp8e4m3 S_dec_b_fp8_map4 = S_dec_b_fp8.map4; + const fp8e4m3 S_dec_b_fp8_map6 = S_dec_b_fp8.map6; + const float S_dec_b_fp32_map6 = static_cast(S_dec_b_fp8_map6); + const float S_dec_b_fp32_map4 = static_cast(S_dec_b_fp8_map4); + + float scale_reciprocal_map6 = 0.0f; + if (S_dec_b_fp32_map6 != 0.0f) { + scale_reciprocal_map6 = + fminf(S_enc / S_dec_b_fp32_map6, Numeric_Traits::maxNorm); + } + + float scale_reciprocal_map4 = 0.0f; + if (S_dec_b_fp32_map4 != 0.0f) { + scale_reciprocal_map4 = + fminf(S_enc / S_dec_b_fp32_map4, Numeric_Traits::maxNorm); + } + + if (use_fast_math) { + scale_reciprocal_map6 = static_cast(static_cast(scale_reciprocal_map6)); + scale_reciprocal_map4 = static_cast(static_cast(scale_reciprocal_map4)); + } + + float err_map6 = 0.0f; + float err_map4 = 0.0f; + constexpr float mse_scale = 1.0f / (6.0f * 256.0f); + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; j += 2) { + const float input_x = static_cast(input[i * cols + j]); + const float act_x = OP(input_x); + const float cached_x = static_cast(static_cast(act_x)); + float cached_y = 0.0f; + if (j + 1 < j_max) { + const float input_y = static_cast(input[i * cols + j + 1]); + const float act_y = OP(input_y); + cached_y = static_cast(static_cast(act_y)); + } + + const float2 scaled_elt_pair_map6 = { + cached_x * scale_reciprocal_map6, + cached_y * scale_reciprocal_map6, + }; + const fp4e2m1x2 casted_to_e2m1_pair_map6(scaled_elt_pair_map6); + const double2 truncated_pair_map6 = + cvt_fp4x2_to_double2(casted_to_e2m1_pair_map6); + const float dequant_x_map6 = + static_cast(truncated_pair_map6.x) * S_dec_b_fp32_map6 * + global_amax * mse_scale; + const float dequant_y_map6 = + static_cast(truncated_pair_map6.y) * S_dec_b_fp32_map6 * + global_amax * mse_scale; + err_map6 += (dequant_x_map6 - cached_x) * (dequant_x_map6 - cached_x); + err_map6 += (dequant_y_map6 - cached_y) * (dequant_y_map6 - cached_y); + + const float2 scaled_elt_pair_map4 = { + cached_x * scale_reciprocal_map4, + cached_y * scale_reciprocal_map4, + }; + const fp4e2m1x2 casted_to_e2m1_pair_map4(scaled_elt_pair_map4); + const double2 truncated_pair_map4 = + cvt_fp4x2_to_double2(casted_to_e2m1_pair_map4); + const float dequant_x_map4 = + static_cast(truncated_pair_map4.x) * S_dec_b_fp32_map4 * + global_amax * mse_scale; + const float dequant_y_map4 = + static_cast(truncated_pair_map4.y) * S_dec_b_fp32_map4 * + global_amax * mse_scale; + err_map4 += (dequant_x_map4 - cached_x) * (dequant_x_map4 - cached_x); + err_map4 += (dequant_y_map4 - cached_y) * (dequant_y_map4 - cached_y); + } + } + + if (err_map4 < err_map6) { + math_scales[block_Y][block_X] = S_dec_b_fp8_map4; + } else { + math_scales[block_Y][block_X] = S_dec_b_fp8_map6; + } + } else { + math_scales[block_Y][block_X] = S_dec_b_fp8_map6; + } } } } @@ -626,10 +730,14 @@ template void performTest(float (*OP)(const float), const std::vector& shape, const bool use_fast_math, + const bool use_2d_quantization = false, const bool row_scaled_nvfp4 = false, const bool use_4over6 = false) { using namespace test; + NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4), + "2D quantization and row-scaling are not supported together."); + DType itype = TypeInfo::dtype; DType otype = DType::kFloat4E2M1; @@ -702,8 +810,6 @@ void performTest(float (*OP)(const float), output.from_cpu(); } - // Reference implementation - bool use_2d_quantization = false; compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -813,6 +919,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam transformer_engine::DType, bool, bool, + bool, bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { @@ -828,8 +935,9 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); - const bool row_scaled_nvfp4 = std::get<4>(GetParam()); - const bool use_4over6 = std::get<5>(GetParam()); + const bool use_2d_quantization = std::get<4>(GetParam()); + const bool row_scaled_nvfp4 = std::get<5>(GetParam()); + const bool use_4over6 = std::get<6>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -847,7 +955,8 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, row_scaled_nvfp4, use_4over6); + performTest(OP, tensor_dims, use_fast_math, use_2d_quantization, + row_scaled_nvfp4, use_4over6); ); } @@ -872,6 +981,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kBFloat16), ::testing::Values(false), ::testing::Values(false), + ::testing::Values(false), ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); @@ -883,7 +993,10 @@ INSTANTIATE_TEST_SUITE_P( if (std::get<3>(info.param)) { name += "X_FAST_SCALING"; } - if (std::get<5>(info.param)) { + if (std::get<4>(info.param)) { + name += "X2D"; + } + if (std::get<6>(info.param)) { name += "X4OVER6"; } return name; @@ -897,6 +1010,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), ::testing::Values(DType::kBFloat16, DType::kFloat32), ::testing::Values(false), + ::testing::Values(false), ::testing::Values(true), ::testing::Values(false)), [](const testing::TestParamInfo& info) { @@ -910,9 +1024,12 @@ INSTANTIATE_TEST_SUITE_P( name += "X_FAST_SCALING"; } if (std::get<4>(info.param)) { - name += "XROW_SCALED"; + name += "X2D"; } if (std::get<5>(info.param)) { + name += "XROW_SCALED"; + } + if (std::get<6>(info.param)) { name += "X4OVER6"; } return name; @@ -924,9 +1041,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::Values(ActivationType::Identity), ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), - ::testing::Values(DType::kFloat32), + ::testing::Values(DType::kBFloat16, DType::kFloat32), + ::testing::Values(false), + ::testing::Values(false), ::testing::Values(false), - ::testing::Values(false, true), ::testing::Values(true)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); @@ -936,8 +1054,55 @@ INSTANTIATE_TEST_SUITE_P( } name += "X" + test::typeName(std::get<2>(info.param)); if (std::get<4>(info.param)) { + name += "X2D"; + } + if (std::get<5>(info.param)) { name += "XROW_SCALED"; } name += "X4OVER6"; return name; }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest4Over6RowScaled, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::Values(ActivationType::Identity), + ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), + ::testing::Values(DType::kFloat32), + ::testing::Values(false), + ::testing::Values(false), + ::testing::Values(true), + ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + name += "XROW_SCALEDX4OVER6"; + return name; + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest4Over62D, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::Values(ActivationType::Identity), + ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), + ::testing::Values(DType::kBFloat16), + ::testing::Values(false), + ::testing::Values(true), + ::testing::Values(false), + ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + name += "X2DX4OVER6"; + return name; + }); diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 14aab1c208..d9ec2de947 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -37,9 +37,12 @@ __device__ __forceinline__ void compute_4over6_decoding_scaling_factors( const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, nvfp4_scale_t &S_dec_b_fp8_map6) { constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f - const float sf_high_precision = block_amax / fp4_max * S_enc; - S_dec_b_fp8_map4 = static_cast(sf_high_precision * 1.5f); - S_dec_b_fp8_map6 = static_cast(sf_high_precision); + constexpr float scale_expansion_factor = 1.5f; + const float base_sf_high_precision = block_amax / fp4_max * S_enc; + const float sf_high_precision_map4 = base_sf_high_precision * scale_expansion_factor; + const float sf_high_precision_map6 = base_sf_high_precision; + S_dec_b_fp8_map4 = static_cast(sf_high_precision_map4); + S_dec_b_fp8_map6 = static_cast(sf_high_precision_map6); } template diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 3bc84ca785..d1c30f84f5 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -478,8 +478,9 @@ def _quantize_blockwise_4over6_reference( FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) GLOBAL_SCALE_E4M3_MAX = torch.tensor(256.0, device=x.device, dtype=torch.float32) - decode_scale_map6 = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale - decode_scale_map4 = decode_scale_map6 * 1.5 + decode_scale_base = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale + decode_scale_map4 = decode_scale_base * 1.5 + decode_scale_map6 = decode_scale_base decode_scale_map4 = torch.clamp( decode_scale_map4, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX ).to(torch.float8_e4m3fn) From 8607e0327d19879d13c32a7303816251fee7c6b4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 23:39:10 -0700 Subject: [PATCH 37/57] Clean up cpp test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 325 ++++++++---------- 1 file changed, 135 insertions(+), 190 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 36f26ca070..2e88f2dc1a 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -83,21 +83,102 @@ float compute_global_encode_scaling_factor_FP4(const float global_amax, const bo return global_encode_scale; } -struct NVFP4FourOverSixDecodeScales { - fp8e4m3 map4; - fp8e4m3 map6; +struct NVFP4FourOverSixQuantization { + fp8e4m3 scale_map4; + fp8e4m3 scale_map6; + float reciprocal_map4; + float reciprocal_map6; + fp4e2m1x2 quantized_map4; + fp4e2m1x2 quantized_map6; + float error_map4; + float error_map6; }; -NVFP4FourOverSixDecodeScales compute_4over6_decoding_scaling_factors( - const float block_amax, const float global_encode_scale) { +NVFP4FourOverSixQuantization compute_4over6_quantization_scales( + const float block_amax, const float global_encode_scale, const bool use_fast_math) { constexpr float fp4_max = 6.0f; constexpr float scale_expansion_factor = 1.5f; const float base_sf_high_precision = block_amax / fp4_max * global_encode_scale; const float sf_high_precision_map4 = base_sf_high_precision * scale_expansion_factor; const float sf_high_precision_map6 = base_sf_high_precision; + const fp8e4m3 scale_map4 = static_cast(sf_high_precision_map4); + const fp8e4m3 scale_map6 = static_cast(sf_high_precision_map6); + + float reciprocal_map4 = 0.0f; + const float scale_map4_fp32 = static_cast(scale_map4); + if (scale_map4_fp32 != 0.0f) { + reciprocal_map4 = fminf(global_encode_scale / scale_map4_fp32, + Numeric_Traits::maxNorm); + } + + float reciprocal_map6 = 0.0f; + const float scale_map6_fp32 = static_cast(scale_map6); + if (scale_map6_fp32 != 0.0f) { + reciprocal_map6 = fminf(global_encode_scale / scale_map6_fp32, + Numeric_Traits::maxNorm); + } + + if (use_fast_math) { + reciprocal_map4 = static_cast(static_cast(reciprocal_map4)); + reciprocal_map6 = static_cast(static_cast(reciprocal_map6)); + } + + const float2 zero = {0.0f, 0.0f}; return { - static_cast(sf_high_precision_map4), - static_cast(sf_high_precision_map6), + scale_map4, + scale_map6, + reciprocal_map4, + reciprocal_map6, + fp4e2m1x2(zero), + fp4e2m1x2(zero), + 0.0f, + 0.0f, + }; +} + +float compute_4over6_dequantized_value(const double quantized_value, + const fp8e4m3 scale, + const float global_amax) { + constexpr float mse_scale = 1.0f / (6.0f * 256.0f); + return static_cast(quantized_value) * static_cast(scale) * global_amax * + mse_scale; +} + +float compute_squared_error(const float value, const float reference) { + const float diff = value - reference; + return diff * diff; +} + +NVFP4FourOverSixQuantization quantize_4over6_pair( + const float x, const float y, const NVFP4FourOverSixQuantization& quantization, + const float global_amax) { + const float2 scaled_map4 = {x * quantization.reciprocal_map4, + y * quantization.reciprocal_map4}; + const fp4e2m1x2 quantized_map4(scaled_map4); + const double2 truncated_map4 = cvt_fp4x2_to_double2(quantized_map4); + const float dequant_x_map4 = + compute_4over6_dequantized_value(truncated_map4.x, quantization.scale_map4, global_amax); + const float dequant_y_map4 = + compute_4over6_dequantized_value(truncated_map4.y, quantization.scale_map4, global_amax); + + const float2 scaled_map6 = {x * quantization.reciprocal_map6, + y * quantization.reciprocal_map6}; + const fp4e2m1x2 quantized_map6(scaled_map6); + const double2 truncated_map6 = cvt_fp4x2_to_double2(quantized_map6); + const float dequant_x_map6 = + compute_4over6_dequantized_value(truncated_map6.x, quantization.scale_map6, global_amax); + const float dequant_y_map6 = + compute_4over6_dequantized_value(truncated_map6.y, quantization.scale_map6, global_amax); + + return { + quantization.scale_map4, + quantization.scale_map6, + quantization.reciprocal_map4, + quantization.reciprocal_map6, + quantized_map4, + quantized_map6, + compute_squared_error(dequant_x_map4, x) + compute_squared_error(dequant_y_map4, y), + compute_squared_error(dequant_x_map6, x) + compute_squared_error(dequant_y_map6, y), }; } @@ -148,84 +229,36 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t scale_idx = i * scales_stride + block_X; if (use_4over6) { - const NVFP4FourOverSixDecodeScales S_dec_b_fp8 = - compute_4over6_decoding_scaling_factors(block_amax, S_enc); - const fp8e4m3 S_dec_b_fp8_map4 = S_dec_b_fp8.map4; - const fp8e4m3 S_dec_b_fp8_map6 = S_dec_b_fp8.map6; - const float S_dec_b_fp32_map6 = static_cast(S_dec_b_fp8_map6); - const float S_dec_b_fp32_map4 = static_cast(S_dec_b_fp8_map4); - - float scale_reciprocal_map6 = 0.0f; - if (S_dec_b_fp32_map6 != 0.0f) { - scale_reciprocal_map6 = - fminf(S_enc / S_dec_b_fp32_map6, Numeric_Traits::maxNorm); - } - - float scale_reciprocal_map4 = 0.0f; - if (S_dec_b_fp32_map4 != 0.0f) { - scale_reciprocal_map4 = - fminf(S_enc / S_dec_b_fp32_map4, Numeric_Traits::maxNorm); - } - - if (use_fast_math) { - scale_reciprocal_map6 = static_cast(static_cast(scale_reciprocal_map6)); - scale_reciprocal_map4 = static_cast(static_cast(scale_reciprocal_map4)); - } + const NVFP4FourOverSixQuantization quantization = + compute_4over6_quantization_scales(block_amax, S_enc, use_fast_math); std::array output_map6; std::array output_map4; float err_map6 = 0.0f; float err_map4 = 0.0f; - constexpr float mse_scale = 1.0f / (6.0f * 256.0f); for (size_t j = j_min; j < j_max; j += 2) { const int cache_idx_x = j - j_min; const int cache_idx_y = cache_idx_x + 1; const float cached_x = cache_buffer[cache_idx_x]; const float cached_y = cache_buffer[cache_idx_y]; - - const float2 scaled_elt_pair_map6 = { - cached_x * scale_reciprocal_map6, - cached_y * scale_reciprocal_map6, - }; - const fp4e2m1x2 casted_to_e2m1_pair_map6(scaled_elt_pair_map6); - output_map6[cache_idx_x / 2] = casted_to_e2m1_pair_map6; - const double2 truncated_pair_map6 = cvt_fp4x2_to_double2(casted_to_e2m1_pair_map6); - const float dequant_x_map6 = - static_cast(truncated_pair_map6.x) * S_dec_b_fp32_map6 * - global_amax * mse_scale; - const float dequant_y_map6 = - static_cast(truncated_pair_map6.y) * S_dec_b_fp32_map6 * - global_amax * mse_scale; - err_map6 += (dequant_x_map6 - cached_x) * (dequant_x_map6 - cached_x); - err_map6 += (dequant_y_map6 - cached_y) * (dequant_y_map6 - cached_y); - - const float2 scaled_elt_pair_map4 = { - cached_x * scale_reciprocal_map4, - cached_y * scale_reciprocal_map4, - }; - const fp4e2m1x2 casted_to_e2m1_pair_map4(scaled_elt_pair_map4); - output_map4[cache_idx_x / 2] = casted_to_e2m1_pair_map4; - const double2 truncated_pair_map4 = cvt_fp4x2_to_double2(casted_to_e2m1_pair_map4); - const float dequant_x_map4 = - static_cast(truncated_pair_map4.x) * S_dec_b_fp32_map4 * - global_amax * mse_scale; - const float dequant_y_map4 = - static_cast(truncated_pair_map4.y) * S_dec_b_fp32_map4 * - global_amax * mse_scale; - err_map4 += (dequant_x_map4 - cached_x) * (dequant_x_map4 - cached_x); - err_map4 += (dequant_y_map4 - cached_y) * (dequant_y_map4 - cached_y); + const NVFP4FourOverSixQuantization pair_quantization = + quantize_4over6_pair(cached_x, cached_y, quantization, global_amax); + output_map4[cache_idx_x / 2] = pair_quantization.quantized_map4; + output_map6[cache_idx_x / 2] = pair_quantization.quantized_map6; + err_map4 += pair_quantization.error_map4; + err_map6 += pair_quantization.error_map6; } const bool pick_map4 = err_map4 < err_map6; if (pick_map4) { - scales[scale_idx] = S_dec_b_fp8_map4; + scales[scale_idx] = quantization.scale_map4; for (size_t j = j_min; j < j_max; j += 2) { const int idx_pair = (i * cols + j) / 2; output[idx_pair] = output_map4[(j - j_min) / 2]; } } else { - scales[scale_idx] = S_dec_b_fp8_map6; + scales[scale_idx] = quantization.scale_map6; for (size_t j = j_min; j < j_max; j += 2) { const int idx_pair = (i * cols + j) / 2; output[idx_pair] = output_map6[(j - j_min) / 2]; @@ -309,36 +342,12 @@ void compute_2d_mathematical_scales(float (*OP)(const float), } // Compute E4M3 scaling factor for this 16x16 block - const float S_dec_b = block_amax / 6.0f * S_enc; - const fp8e4m3 S_dec_b_fp8_map6 = static_cast(S_dec_b); if (use_4over6) { - const NVFP4FourOverSixDecodeScales S_dec_b_fp8 = - compute_4over6_decoding_scaling_factors(block_amax, S_enc); - const fp8e4m3 S_dec_b_fp8_map4 = S_dec_b_fp8.map4; - const fp8e4m3 S_dec_b_fp8_map6 = S_dec_b_fp8.map6; - const float S_dec_b_fp32_map6 = static_cast(S_dec_b_fp8_map6); - const float S_dec_b_fp32_map4 = static_cast(S_dec_b_fp8_map4); - - float scale_reciprocal_map6 = 0.0f; - if (S_dec_b_fp32_map6 != 0.0f) { - scale_reciprocal_map6 = - fminf(S_enc / S_dec_b_fp32_map6, Numeric_Traits::maxNorm); - } - - float scale_reciprocal_map4 = 0.0f; - if (S_dec_b_fp32_map4 != 0.0f) { - scale_reciprocal_map4 = - fminf(S_enc / S_dec_b_fp32_map4, Numeric_Traits::maxNorm); - } - - if (use_fast_math) { - scale_reciprocal_map6 = static_cast(static_cast(scale_reciprocal_map6)); - scale_reciprocal_map4 = static_cast(static_cast(scale_reciprocal_map4)); - } + const NVFP4FourOverSixQuantization quantization = + compute_4over6_quantization_scales(block_amax, S_enc, use_fast_math); float err_map6 = 0.0f; float err_map4 = 0.0f; - constexpr float mse_scale = 1.0f / (6.0f * 256.0f); for (size_t i = i_min; i < i_max; ++i) { for (size_t j = j_min; j < j_max; j += 2) { const float input_x = static_cast(input[i * cols + j]); @@ -350,47 +359,21 @@ void compute_2d_mathematical_scales(float (*OP)(const float), const float act_y = OP(input_y); cached_y = static_cast(static_cast(act_y)); } - - const float2 scaled_elt_pair_map6 = { - cached_x * scale_reciprocal_map6, - cached_y * scale_reciprocal_map6, - }; - const fp4e2m1x2 casted_to_e2m1_pair_map6(scaled_elt_pair_map6); - const double2 truncated_pair_map6 = - cvt_fp4x2_to_double2(casted_to_e2m1_pair_map6); - const float dequant_x_map6 = - static_cast(truncated_pair_map6.x) * S_dec_b_fp32_map6 * - global_amax * mse_scale; - const float dequant_y_map6 = - static_cast(truncated_pair_map6.y) * S_dec_b_fp32_map6 * - global_amax * mse_scale; - err_map6 += (dequant_x_map6 - cached_x) * (dequant_x_map6 - cached_x); - err_map6 += (dequant_y_map6 - cached_y) * (dequant_y_map6 - cached_y); - - const float2 scaled_elt_pair_map4 = { - cached_x * scale_reciprocal_map4, - cached_y * scale_reciprocal_map4, - }; - const fp4e2m1x2 casted_to_e2m1_pair_map4(scaled_elt_pair_map4); - const double2 truncated_pair_map4 = - cvt_fp4x2_to_double2(casted_to_e2m1_pair_map4); - const float dequant_x_map4 = - static_cast(truncated_pair_map4.x) * S_dec_b_fp32_map4 * - global_amax * mse_scale; - const float dequant_y_map4 = - static_cast(truncated_pair_map4.y) * S_dec_b_fp32_map4 * - global_amax * mse_scale; - err_map4 += (dequant_x_map4 - cached_x) * (dequant_x_map4 - cached_x); - err_map4 += (dequant_y_map4 - cached_y) * (dequant_y_map4 - cached_y); + const NVFP4FourOverSixQuantization pair_quantization = + quantize_4over6_pair(cached_x, cached_y, quantization, global_amax); + err_map4 += pair_quantization.error_map4; + err_map6 += pair_quantization.error_map6; } } if (err_map4 < err_map6) { - math_scales[block_Y][block_X] = S_dec_b_fp8_map4; + math_scales[block_Y][block_X] = quantization.scale_map4; } else { - math_scales[block_Y][block_X] = S_dec_b_fp8_map6; + math_scales[block_Y][block_X] = quantization.scale_map6; } } else { + const float S_dec_b = block_amax / 6.0f * S_enc; + const fp8e4m3 S_dec_b_fp8_map6 = static_cast(S_dec_b); math_scales[block_Y][block_X] = S_dec_b_fp8_map6; } } @@ -972,6 +955,28 @@ std::string to_string(const ActivationType Act_type) { } } +std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) { + std::string name = to_string(std::get<0>(param)); + const auto& shape = std::get<1>(param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(param)); + if (std::get<3>(param)) { + name += "X_FAST_SCALING"; + } + if (std::get<4>(param)) { + name += "X2D"; + } + if (std::get<5>(param)) { + name += "XROW_SCALED"; + } + if (std::get<6>(param)) { + name += "X4OVER6"; + } + return name; +} + INSTANTIATE_TEST_SUITE_P( OperatorTest, FusedCastTransposeNVFP4TestSuite, @@ -984,22 +989,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(false)), [](const testing::TestParamInfo& info) { - std::string name = to_string(std::get<0>(info.param)); - const auto& shape = std::get<1>(info.param); - for ( const auto& s: shape) { - name += "X" + std::to_string(s); - } - name += "X" + test::typeName(std::get<2>(info.param)); - if (std::get<3>(info.param)) { - name += "X_FAST_SCALING"; - } - if (std::get<4>(info.param)) { - name += "X2D"; - } - if (std::get<6>(info.param)) { - name += "X4OVER6"; - } - return name; + return test_name(info.param); }); INSTANTIATE_TEST_SUITE_P( @@ -1014,25 +1004,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(true), ::testing::Values(false)), [](const testing::TestParamInfo& info) { - std::string name = to_string(std::get<0>(info.param)); - const auto& shape = std::get<1>(info.param); - for (const auto& s: shape) { - name += "X" + std::to_string(s); - } - name += "X" + test::typeName(std::get<2>(info.param)); - if (std::get<3>(info.param)) { - name += "X_FAST_SCALING"; - } - if (std::get<4>(info.param)) { - name += "X2D"; - } - if (std::get<5>(info.param)) { - name += "XROW_SCALED"; - } - if (std::get<6>(info.param)) { - name += "X4OVER6"; - } - return name; + return test_name(info.param); }); INSTANTIATE_TEST_SUITE_P( @@ -1047,20 +1019,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(true)), [](const testing::TestParamInfo& info) { - std::string name = to_string(std::get<0>(info.param)); - const auto& shape = std::get<1>(info.param); - for (const auto& s: shape) { - name += "X" + std::to_string(s); - } - name += "X" + test::typeName(std::get<2>(info.param)); - if (std::get<4>(info.param)) { - name += "X2D"; - } - if (std::get<5>(info.param)) { - name += "XROW_SCALED"; - } - name += "X4OVER6"; - return name; + return test_name(info.param); }); INSTANTIATE_TEST_SUITE_P( @@ -1075,14 +1034,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(true), ::testing::Values(true)), [](const testing::TestParamInfo& info) { - std::string name = to_string(std::get<0>(info.param)); - const auto& shape = std::get<1>(info.param); - for (const auto& s: shape) { - name += "X" + std::to_string(s); - } - name += "X" + test::typeName(std::get<2>(info.param)); - name += "XROW_SCALEDX4OVER6"; - return name; + return test_name(info.param); }); INSTANTIATE_TEST_SUITE_P( @@ -1097,12 +1049,5 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(true)), [](const testing::TestParamInfo& info) { - std::string name = to_string(std::get<0>(info.param)); - const auto& shape = std::get<1>(info.param); - for (const auto& s: shape) { - name += "X" + std::to_string(s); - } - name += "X" + test::typeName(std::get<2>(info.param)); - name += "X2DX4OVER6"; - return name; + return test_name(info.param); }); From d3dbf34d238ffa96d26364dda59ca9040ce6db88 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 23:44:08 -0700 Subject: [PATCH 38/57] Minor comment Signed-off-by: Ziang Li --- .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 75886fe54b..d768108eaa 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -921,6 +921,7 @@ __global__ void __launch_bounds__(THREADS_NUM) __shared__ alignas(8) uint64_t mbar[STAGES]; __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + // Only used for 4over6 quantization __shared__ __align__(16) float err_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; __shared__ __align__(16) float err_map6_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; __shared__ __align__(16) uint8_t pick_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; From 565f33f69427bae8a65166a8cbfa62ef2c24a4be Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 11 May 2026 23:57:16 -0700 Subject: [PATCH 39/57] Drop doc Signed-off-by: Ziang Li --- docs/envvars.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 82ada36c19..5ad271837c 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -261,7 +261,7 @@ Kernel Configuration :Type: ``int`` (0 or 1) :Default: ``0`` - :Description: Enable fast math optimizations in supported kernels, including runtime-compiled (NVRTC) kernels and NVFP4 4over6 quantization. This trades numerical accuracy for performance. These optimizations are experimental and inconsistently implemented. + :Description: Enable fast math optimizations in runtime-compiled (NVRTC) kernels. This trades numerical accuracy for performance. These optimizations are experimental and inconsistently implemented. .. envvar:: NVTE_DISABLE_NVRTC From 54b4da85961503dbb739ef7450c25a4a11725d38 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 02:04:19 -0700 Subject: [PATCH 40/57] Explicit handle conditional smem buffer Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 30 +++++-- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 82 +++++++++++++------ 2 files changed, 79 insertions(+), 33 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index d9ec2de947..cf93e80577 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -312,6 +312,14 @@ struct QuantizationCandidates4Over6 { uint32_t rOut_map6[2]; }; +template +struct alignas(16) QuantizationScratch4Over6 { + alignas(16) float err_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; + alignas(16) float err_map6_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; + alignas(16) uint8_t pick_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; + alignas(16) nvfp4_scale_t selected_scale_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; +}; + template __device__ __forceinline__ void quantize_4over6_candidates_16x( const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, @@ -368,15 +376,25 @@ __device__ __forceinline__ bool record_and_select_4over6_2d_block( return pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; } +template +__device__ __forceinline__ bool record_and_select_4over6_2d_block( + const QuantizationScales4Over6 &scaling_factors, + const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, + QuantizationScratch4Over6 &scratch, + nvfp4_scale_t &S_dec_b_fp8, const QuantizationCandidates4Over6 &candidates) { + return record_and_select_4over6_2d_block( + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, + scratch.err_map4_matrix, scratch.err_map6_matrix, scratch.pick_map4_matrix, + scratch.selected_scale_matrix, S_dec_b_fp8, candidates); +} + template __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( const float (&x)[16], const float block_amax, const float global_encode_scale, const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, - float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], - float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], - uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], - nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], + QuantizationScratch4Over6 &scratch, nvfp4_scale_t &S_dec_b_fp8, QuantizationCandidates4Over6 &candidates) { const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( block_amax, global_encode_scale, global_decode_scale); @@ -384,8 +402,8 @@ __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( const bool pick_map4 = record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, err_map4_matrix, - err_map6_matrix, pick_map4_matrix, selected_scale_matrix, S_dec_b_fp8, candidates); + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, scratch, + S_dec_b_fp8, candidates); return pick_map4; } diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index d768108eaa..0833d02eff 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -181,6 +181,10 @@ constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4; constexpr size_t TILE_DIM_Y = 32; constexpr size_t TILE_DIM_X = 128; +constexpr size_t NVFP4_2D_BLOCK_DIM = 16; +constexpr size_t NVFP4_2D_BLOCKS_PER_TILE_Y = TILE_DIM_Y / NVFP4_2D_BLOCK_DIM; +constexpr size_t NVFP4_2D_BLOCKS_PER_TILE_X = TILE_DIM_X / NVFP4_2D_BLOCK_DIM; + // SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 @@ -814,9 +818,9 @@ __global__ void __launch_bounds__(THREADS_NUM) 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x // NEW: 2D Block-based scaling constants - constexpr size_t BLOCK_DIM = 16; - constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 - constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 + constexpr size_t BLOCK_DIM = NVFP4_2D_BLOCK_DIM; + constexpr size_t BLOCKS_PER_TILE_Y = NVFP4_2D_BLOCKS_PER_TILE_Y; + constexpr size_t BLOCKS_PER_TILE_X = NVFP4_2D_BLOCKS_PER_TILE_X; constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 @@ -871,6 +875,8 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; constexpr size_t out_mem_colwise_data = buff_size_aligned_out; constexpr size_t out_mem_rowwise_scales = 0; + constexpr size_t out_mem_colwise_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM * sizeof(nvfp4_scale_t); extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); @@ -890,6 +896,17 @@ __global__ void __launch_bounds__(THREADS_NUM) dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + using FourOverSixScratch = + QuantizationScratch4Over6; + FourOverSixScratch *four_over_six_scratch = nullptr; + if constexpr (USE_4OVER6) { + constexpr size_t four_over_six_scratch_offset = in_mem + out_mem_rowwise_data + + out_mem_colwise_data + out_mem_rowwise_scales + + out_mem_colwise_scales; + four_over_six_scratch = + reinterpret_cast(dshmem + four_over_six_scratch_offset); + } + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; const bool is_master_thread = (threadIdx.x == 0); @@ -921,12 +938,6 @@ __global__ void __launch_bounds__(THREADS_NUM) __shared__ alignas(8) uint64_t mbar[STAGES]; __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; - // Only used for 4over6 quantization - __shared__ __align__(16) float err_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; - __shared__ __align__(16) float err_map6_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; - __shared__ __align__(16) uint8_t pick_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; - __shared__ __align__(16) - nvfp4_scale_t selected_scale_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; // Helper function for warp reduction auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { @@ -1105,8 +1116,8 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_and_select_4over6_2d_block_16x( x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, - block_in_tile_y, block_in_tile_x, block_col, err_map4_matrix, err_map6_matrix, - pick_map4_matrix, selected_scale_matrix, S_dec_b_fp8, candidates); + block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, + candidates); const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; @@ -1268,8 +1279,8 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_and_select_4over6_2d_block_16x( in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, - block_in_tile_y, block_in_tile_x, tid_Y_rowwise, err_map4_matrix, err_map6_matrix, - pick_map4_matrix, selected_scale_matrix, S_dec_b_fp8, candidates); + block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, + S_dec_b_fp8, candidates); const size_t scales_offset_Y = scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; @@ -1520,7 +1531,8 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; - constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + constexpr size_t base_dshmem_size = + in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, @@ -1529,23 +1541,39 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { TRANSFORMER_ENGINE_SWITCH_CONDITION(use_4over6, USE_4OVER6, { TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; - - if constexpr (use_2d_quantization) { - kernel = + auto launch_kernel = [&](auto kernel, const size_t dshmem_size) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, + cols, scale_stride, scale_stride_transpose, rng_state); + }; + + if constexpr (use_2d_quantization && USE_4OVER6) { + using FourOverSixScratch = core::QuantizationScratch4Over6< + NVFP4_2D_BLOCK_DIM, NVFP4_2D_BLOCKS_PER_TILE_Y, NVFP4_2D_BLOCKS_PER_TILE_X>; + constexpr size_t dshmem_size = + base_dshmem_size + + DIVUP_TO_MULTIPLE(sizeof(FourOverSixScratch), TMA_SHMEM_ALIGNMENT); + auto kernel = quantize_transpose_nvfp4_2D_kernel; + launch_kernel(kernel, dshmem_size); + } else { + constexpr size_t dshmem_size = base_dshmem_size; + auto kernel = quantize_transpose_nvfp4_kernel; + if constexpr (use_2d_quantization) { + kernel = + quantize_transpose_nvfp4_2D_kernel; + } + launch_kernel(kernel, dshmem_size); } - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); }); }); }););); From fa09200066892dd20fc6a93c612786d9786a7449 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 02:17:08 -0700 Subject: [PATCH 41/57] Further clean up Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 10 ++-- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 49 ++++++++----------- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index cf93e80577..a5c88a9160 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -384,9 +384,9 @@ __device__ __forceinline__ bool record_and_select_4over6_2d_block( QuantizationScratch4Over6 &scratch, nvfp4_scale_t &S_dec_b_fp8, const QuantizationCandidates4Over6 &candidates) { return record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, - scratch.err_map4_matrix, scratch.err_map6_matrix, scratch.pick_map4_matrix, - scratch.selected_scale_matrix, S_dec_b_fp8, candidates); + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, scratch.err_map4_matrix, + scratch.err_map6_matrix, scratch.pick_map4_matrix, scratch.selected_scale_matrix, S_dec_b_fp8, + candidates); } template @@ -402,8 +402,8 @@ __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( const bool pick_map4 = record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, scratch, - S_dec_b_fp8, candidates); + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, scratch, S_dec_b_fp8, + candidates); return pick_map4; } diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 0833d02eff..036ff74cd5 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1541,39 +1541,30 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { TRANSFORMER_ENGINE_SWITCH_CONDITION(use_4over6, USE_4OVER6, { TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto launch_kernel = [&](auto kernel, const size_t dshmem_size) { - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, - cols, scale_stride, scale_stride_transpose, rng_state); - }; - - if constexpr (use_2d_quantization && USE_4OVER6) { - using FourOverSixScratch = core::QuantizationScratch4Over6< - NVFP4_2D_BLOCK_DIM, NVFP4_2D_BLOCKS_PER_TILE_Y, NVFP4_2D_BLOCKS_PER_TILE_X>; - constexpr size_t dshmem_size = - base_dshmem_size + - DIVUP_TO_MULTIPLE(sizeof(FourOverSixScratch), TMA_SHMEM_ALIGNMENT); - auto kernel = + auto kernel = quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; - launch_kernel(kernel, dshmem_size); - } else { - constexpr size_t dshmem_size = base_dshmem_size; - auto kernel = quantize_transpose_nvfp4_kernel; - if constexpr (use_2d_quantization) { - kernel = - quantize_transpose_nvfp4_2D_kernel; - } - launch_kernel(kernel, dshmem_size); } + using FourOverSixScratch = + core::QuantizationScratch4Over6; + constexpr size_t dshmem_size = + base_dshmem_size + + ((use_2d_quantization && USE_4OVER6) + ? DIVUP_TO_MULTIPLE(sizeof(FourOverSixScratch), TMA_SHMEM_ALIGNMENT) + : 0); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); }); }); }););); From e57e8beb8c501434400df349331b4c18f0c69e5d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 02:38:18 -0700 Subject: [PATCH 42/57] More templates Signed-off-by: Ziang Li --- .../common/cast/nvfp4/quantize_4over6_nvfp4.cuh | 9 +++++++++ .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 5 ++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index a5c88a9160..47c0eede25 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -318,6 +318,15 @@ struct alignas(16) QuantizationScratch4Over6 { alignas(16) float err_map6_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; alignas(16) uint8_t pick_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; alignas(16) nvfp4_scale_t selected_scale_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; + + template + static constexpr size_t dynamic_shared_memory_size() { + if constexpr (USE_2D_QUANTIZATION && USE_4OVER6) { + return ((sizeof(QuantizationScratch4Over6) + TMA_SHMEM_ALIGNMENT - 1) / TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + } + return 0; + } }; template diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 036ff74cd5..046a60882c 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1556,9 +1556,8 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVFP4_2D_BLOCKS_PER_TILE_X>; constexpr size_t dshmem_size = base_dshmem_size + - ((use_2d_quantization && USE_4OVER6) - ? DIVUP_TO_MULTIPLE(sizeof(FourOverSixScratch), TMA_SHMEM_ALIGNMENT) - : 0); + FourOverSixScratch::template dynamic_shared_memory_size(); cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); kernel<<>>( From a1df31945632a901d27c0e93316a5a4aae70d7b1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 02:57:20 -0700 Subject: [PATCH 43/57] Simplify cpp Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 172 +++++++++--------- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 58 ++---- .../quantize_transpose_nvfp4_tuned_1D.cuh | 24 +-- ...quantize_transpose_vector_blockwise_fp4.cu | 15 +- 4 files changed, 120 insertions(+), 149 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 47c0eede25..c084459e11 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -310,6 +310,18 @@ struct QuantizationCandidates4Over6 { float err_map6; uint32_t rOut_map4[2]; uint32_t rOut_map6[2]; + + __device__ __forceinline__ void reset_errors() { + err_map4 = 0.0f; + err_map6 = 0.0f; + } + + __device__ __forceinline__ const uint32_t *selected_packed(const bool pick_map4) const { + if (pick_map4) { + return rOut_map4; + } + return rOut_map6; + } }; template @@ -329,20 +341,64 @@ struct alignas(16) QuantizationScratch4Over6 { } }; +template +__device__ __forceinline__ void load_4over6_contiguous_halves_16x(const input_type *x, + float (&first_half)[8], + float (&second_half)[8]) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + first_half[i] = static_cast(x[i]); + second_half[i] = static_cast(x[i + 8]); + } +} + +template +__device__ __forceinline__ void load_4over6_pair_array_halves_16x(const pair_type (&x)[2][4], + float (&first_half)[8], + float (&second_half)[8]) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + first_half[2 * i] = static_cast(x[0][i].x); + first_half[2 * i + 1] = static_cast(x[0][i].y); + second_half[2 * i] = static_cast(x[1][i].x); + second_half[2 * i + 1] = static_cast(x[1][i].y); + } +} + +template +__device__ __forceinline__ void load_4over6_vec2_array_halves_16x(const vec_type (&x)[8], + float (&first_half)[8], + float (&second_half)[8]) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + first_half[2 * i] = static_cast(x[i].data.elt[0]); + first_half[2 * i + 1] = static_cast(x[i].data.elt[1]); + second_half[2 * i] = static_cast(x[i + 4].data.elt[0]); + second_half[2 * i + 1] = static_cast(x[i + 4].data.elt[1]); + } +} + +template +__device__ __forceinline__ void load_4over6_vec_index_halves_16x(const vec_type (&x)[16], + const int idx, + float (&first_half)[8], + float (&second_half)[8]) { +#pragma unroll + for (int i = 0; i < 8; ++i) { + first_half[i] = static_cast(x[i].data.elt[idx]); + second_half[i] = static_cast(x[i + 8].data.elt[idx]); + } +} + template __device__ __forceinline__ void quantize_4over6_candidates_16x( const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { float first_half[8]; float second_half[8]; -#pragma unroll - for (int i = 0; i < 8; ++i) { - first_half[i] = x[i]; - second_half[i] = x[i + 8]; - } + load_4over6_contiguous_halves_16x(x, first_half, second_half); - candidates.err_map4 = 0.0f; - candidates.err_map6 = 0.0f; + candidates.reset_errors(); quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); @@ -353,13 +409,10 @@ template &scaling_factors, const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, - float (&err_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], - float (&err_map6_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM], - uint8_t (&pick_map4_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], - nvfp4_scale_t (&selected_scale_matrix)[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X], + QuantizationScratch4Over6 &scratch, nvfp4_scale_t &S_dec_b_fp8, const QuantizationCandidates4Over6 &candidates) { - err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; - err_map6_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map6; + scratch.err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; + scratch.err_map6_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map6; __syncthreads(); if (participant_idx == 0) { @@ -367,35 +420,22 @@ __device__ __forceinline__ bool record_and_select_4over6_2d_block( float block_err_map6 = 0.0f; #pragma unroll for (int i = 0; i < BLOCK_DIM; ++i) { - block_err_map4 += err_map4_matrix[block_in_tile_y][block_in_tile_x][i]; - block_err_map6 += err_map6_matrix[block_in_tile_y][block_in_tile_x][i]; + block_err_map4 += scratch.err_map4_matrix[block_in_tile_y][block_in_tile_x][i]; + block_err_map6 += scratch.err_map6_matrix[block_in_tile_y][block_in_tile_x][i]; } const bool pick_map4 = pick_4over6_map4(block_err_map4, block_err_map6); if (pick_map4) { - pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 1; + scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 1; } else { - pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 0; + scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 0; } - selected_scale_matrix[block_in_tile_y][block_in_tile_x] = + scratch.selected_scale_matrix[block_in_tile_y][block_in_tile_x] = selected_4over6_scale(pick_map4, scaling_factors); } __syncthreads(); - S_dec_b_fp8 = selected_scale_matrix[block_in_tile_y][block_in_tile_x]; - return pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; -} - -template -__device__ __forceinline__ bool record_and_select_4over6_2d_block( - const QuantizationScales4Over6 &scaling_factors, - const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, - QuantizationScratch4Over6 &scratch, - nvfp4_scale_t &S_dec_b_fp8, const QuantizationCandidates4Over6 &candidates) { - return record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, scratch.err_map4_matrix, - scratch.err_map6_matrix, scratch.pick_map4_matrix, scratch.selected_scale_matrix, S_dec_b_fp8, - candidates); + S_dec_b_fp8 = scratch.selected_scale_matrix[block_in_tile_y][block_in_tile_x]; + return scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; } template @@ -416,19 +456,11 @@ __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( return pick_map4; } -__device__ __forceinline__ const uint32_t *selected_4over6_packed( - const bool pick_map4, const QuantizationCandidates4Over6 &candidates) { - if (pick_map4) { - return candidates.rOut_map4; - } - return candidates.rOut_map6; -} - template __device__ __forceinline__ void store_4over6_colwise_packed_16x( const bool pick_map4, const QuantizationCandidates4Over6 &candidates, const int thread_lane, output_type *out_t_data_sh, const size_t shmem_offset_base_colwise_out_t) { - const uint32_t *regs_4x = selected_4over6_packed(pick_map4, candidates); + const uint32_t *regs_4x = candidates.selected_packed(pick_map4); const int group = thread_lane / 16; uint32_t val[2]; switch (group) { @@ -452,7 +484,7 @@ __device__ __forceinline__ void store_4over6_rowwise_packed_16x( const bool pick_map4, const QuantizationCandidates4Over6 &candidates, const int bank_group, const size_t thread_offset_X_rowwise, const size_t shmem_offset_base_rowwise_out, output_type *out_data_sh) { - const uint32_t *packed = selected_4over6_packed(pick_map4, candidates); + const uint32_t *packed = candidates.selected_packed(pick_map4); #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; @@ -465,7 +497,7 @@ __device__ __forceinline__ void store_4over6_rowwise_packed_16x( } template -__device__ __forceinline__ void store_4over6_packed_16x(const uint32_t (&packed)[2], +__device__ __forceinline__ void store_4over6_packed_16x(const uint32_t *packed, output_vec_type &output_vec) { *reinterpret_cast(&output_vec.data.elt[0]) = packed[0]; *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; @@ -475,9 +507,7 @@ template __device__ __forceinline__ void store_selected_4over6_packed_16x( const bool pick_map4, const QuantizationCandidates4Over6 &candidates, output_vec_type &output_vec) { - const uint32_t *packed = selected_4over6_packed(pick_map4, candidates); - *reinterpret_cast(&output_vec.data.elt[0]) = packed[0]; - *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; + store_4over6_packed_16x(candidates.selected_packed(pick_map4), output_vec); } template (x[i]); - second_half[i] = static_cast(x[i + 8]); - } + load_4over6_contiguous_halves_16x(x, first_half, second_half); quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); @@ -504,13 +530,7 @@ __device__ __forceinline__ void quantize_4over6_pair_array_16x( const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - first_half[2 * i] = static_cast(x[0][i].x); - first_half[2 * i + 1] = static_cast(x[0][i].y); - second_half[2 * i] = static_cast(x[1][i].x); - second_half[2 * i + 1] = static_cast(x[1][i].y); - } + load_4over6_pair_array_halves_16x(x, first_half, second_half); quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); @@ -522,16 +542,9 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( const float global_amax, QuantizationCandidates4Over6 &candidates) { float first_half[8]; float second_half[8]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - first_half[2 * i] = static_cast(x[i].data.elt[0]); - first_half[2 * i + 1] = static_cast(x[i].data.elt[1]); - second_half[2 * i] = static_cast(x[i + 4].data.elt[0]); - second_half[2 * i + 1] = static_cast(x[i + 4].data.elt[1]); - } + load_4over6_vec2_array_halves_16x(x, first_half, second_half); - candidates.err_map4 = 0.0f; - candidates.err_map6 = 0.0f; + candidates.reset_errors(); quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); @@ -544,13 +557,7 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_16x( const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - first_half[2 * i] = static_cast(x[i].data.elt[0]); - first_half[2 * i + 1] = static_cast(x[i].data.elt[1]); - second_half[2 * i] = static_cast(x[i + 4].data.elt[0]); - second_half[2 * i + 1] = static_cast(x[i + 4].data.elt[1]); - } + load_4over6_vec2_array_halves_16x(x, first_half, second_half); quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); @@ -563,14 +570,9 @@ __device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( QuantizationCandidates4Over6 &candidates) { float first_half[8]; float second_half[8]; -#pragma unroll - for (int i = 0; i < 8; ++i) { - first_half[i] = static_cast(x[i].data.elt[idx]); - second_half[i] = static_cast(x[i + 8].data.elt[idx]); - } + load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); - candidates.err_map4 = 0.0f; - candidates.err_map6 = 0.0f; + candidates.reset_errors(); quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); @@ -584,11 +586,7 @@ __device__ __forceinline__ void quantize_4over6_vec_index_16x( nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; -#pragma unroll - for (int i = 0; i < 8; ++i) { - first_half[i] = static_cast(x[i].data.elt[idx]); - second_half[i] = static_cast(x[i + 8].data.elt[idx]); - } + load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 046a60882c..cd4def60a8 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1098,6 +1098,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } } + nvfp4_scale_t S_dec_b_fp8; if constexpr (USE_4OVER6) { float x_4over6[SCALE_DIM]; #pragma unroll @@ -1111,7 +1112,6 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t block_col = threadIdx.x % BLOCK_DIM; QuantizationCandidates4Over6 candidates; - nvfp4_scale_t S_dec_b_fp8; const bool pick_map4 = quantize_and_select_4over6_2d_block_16x( @@ -1119,21 +1119,11 @@ __global__ void __launch_bounds__(THREADS_NUM) block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, candidates); - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - store_4over6_colwise_packed_16x(pick_map4, candidates, thread_lane, out_t_data_sh, shmem_offset_base_colwise_out_t); } else { // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise); - - // // Store scaling factors through SHMEM - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_colwise); // Compute "correct" per-block encoding scaling factor constexpr float float_max = detail::TypeExtrema::max; @@ -1177,6 +1167,10 @@ __global__ void __launch_bounds__(THREADS_NUM) out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } + + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; } } @@ -1272,9 +1266,17 @@ __global__ void __launch_bounds__(THREADS_NUM) } } + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + + nvfp4_scale_t S_dec_b_fp8; if constexpr (USE_4OVER6) { QuantizationCandidates4Over6 candidates; - nvfp4_scale_t S_dec_b_fp8; const bool pick_map4 = quantize_and_select_4over6_2d_block_16x( @@ -1282,36 +1284,12 @@ __global__ void __launch_bounds__(THREADS_NUM) block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, S_dec_b_fp8, candidates); - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } - store_4over6_rowwise_packed_16x( pick_map4, candidates, bank_group, thread_offset_X_rowwise, shmem_offset_base_rowwise_out, out_data_sh); } else { // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - - // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); // Compute "correct" per-block encoding scaling factor constexpr float float_max = detail::TypeExtrema::max; @@ -1352,6 +1330,10 @@ __global__ void __launch_bounds__(THREADS_NUM) out.store_to(&out_data_sh[shmem_offset_rowwise]); } } + + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } } } diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index e210a1cb8d..27d8f5fd09 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -331,8 +331,8 @@ __device__ __forceinline__ void rowwise_scaling( } const float block_amax = get_amax_of_pair(thread_amax_2x); + nvfp4_scale_t S_dec_b_fp8; if constexpr (USE_4OVER6) { - nvfp4_scale_t S_dec_b_fp8; float block_S_enc_rowwise; float block_global_amax; if constexpr (ROW_SCALED_NVFP4) { @@ -362,13 +362,6 @@ __device__ __forceinline__ void rowwise_scaling( rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } - // Store scaling factors to SMEM buffer (R2S) - if (SF_storing_thread) { - const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = stage_rowwise_scales_offset_X; - sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; - } - #pragma unroll for (int w = 0; w < WAVES; ++w) { const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; @@ -376,7 +369,6 @@ __device__ __forceinline__ void rowwise_scaling( ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], rOut[w]); } } else { - nvfp4_scale_t S_dec_b_fp8; scaling_coeff_type SFcoefficient; if constexpr (ROW_SCALED_NVFP4) { const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; @@ -393,13 +385,6 @@ __device__ __forceinline__ void rowwise_scaling( compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); } - // Store scaling factors to SMEM buffer (R2S) - if (SF_storing_thread) { - const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = stage_rowwise_scales_offset_X; - sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; - } - // Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -422,6 +407,13 @@ __device__ __forceinline__ void rowwise_scaling( ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); } } + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } } } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index dd7cb24d45..9ce7bb0dc6 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -364,10 +364,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; constexpr int k4Over62DSelectionDim = (kUse4Over6 && kIs2DBlockScaling) ? kFP4BlockScalingSize : 1; - __shared__ float err_map4_smem[k2DBlockAmaxDim][k2DBlockAmaxDim][k4Over62DSelectionDim]; - __shared__ float err_map6_smem[k2DBlockAmaxDim][k2DBlockAmaxDim][k4Over62DSelectionDim]; - __shared__ uint8_t pick_map4_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; - __shared__ ScaleType selected_scale_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; + using FourOverSixScratch = + nvfp4_core::QuantizationScratch4Over6; + __shared__ FourOverSixScratch four_over_six_scratch; // Step 1: Load input to shared memory { @@ -554,8 +554,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const bool pick_map4 = nvfp4_core::record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, err_map4_smem, - err_map6_smem, pick_map4_smem, selected_scale_smem, scale_inv, candidates); + scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, + four_over_six_scratch, scale_inv, candidates); nvfp4_core::store_selected_4over6_packed_16x(pick_map4, candidates, output_vec); } else { @@ -719,8 +719,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo nvfp4_core::record_and_select_4over6_2d_block( scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, - err_map4_smem, err_map6_smem, pick_map4_smem, selected_scale_smem, scale_inv, - candidates); + four_over_six_scratch, scale_inv, candidates); nvfp4_core::store_selected_4over6_packed_16x(pick_map4, candidates, output_vec); } else { From 21720daa369590ab2a958324774dde36bbea0f40 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 03:01:11 -0700 Subject: [PATCH 44/57] Drop write back lifting Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 58 ++++++++++++------- .../quantize_transpose_nvfp4_tuned_1D.cuh | 24 +++++--- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index cd4def60a8..046a60882c 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1098,7 +1098,6 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - nvfp4_scale_t S_dec_b_fp8; if constexpr (USE_4OVER6) { float x_4over6[SCALE_DIM]; #pragma unroll @@ -1112,6 +1111,7 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t block_col = threadIdx.x % BLOCK_DIM; QuantizationCandidates4Over6 candidates; + nvfp4_scale_t S_dec_b_fp8; const bool pick_map4 = quantize_and_select_4over6_2d_block_16x( @@ -1119,11 +1119,21 @@ __global__ void __launch_bounds__(THREADS_NUM) block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, candidates); + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + store_4over6_colwise_packed_16x(pick_map4, candidates, thread_lane, out_t_data_sh, shmem_offset_base_colwise_out_t); } else { // 2. Compute E4M3 scaling factor - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_colwise); + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; // Compute "correct" per-block encoding scaling factor constexpr float float_max = detail::TypeExtrema::max; @@ -1167,10 +1177,6 @@ __global__ void __launch_bounds__(THREADS_NUM) out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } - - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; } } @@ -1266,17 +1272,9 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - - nvfp4_scale_t S_dec_b_fp8; if constexpr (USE_4OVER6) { QuantizationCandidates4Over6 candidates; + nvfp4_scale_t S_dec_b_fp8; const bool pick_map4 = quantize_and_select_4over6_2d_block_16x( @@ -1284,12 +1282,36 @@ __global__ void __launch_bounds__(THREADS_NUM) block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, S_dec_b_fp8, candidates); + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + store_4over6_rowwise_packed_16x( pick_map4, candidates, bank_group, thread_offset_X_rowwise, shmem_offset_base_rowwise_out, out_data_sh); } else { // 2. Compute E4M3 scaling factor - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } // Compute "correct" per-block encoding scaling factor constexpr float float_max = detail::TypeExtrema::max; @@ -1330,10 +1352,6 @@ __global__ void __launch_bounds__(THREADS_NUM) out.store_to(&out_data_sh[shmem_offset_rowwise]); } } - - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } } } diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 27d8f5fd09..e210a1cb8d 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -331,8 +331,8 @@ __device__ __forceinline__ void rowwise_scaling( } const float block_amax = get_amax_of_pair(thread_amax_2x); - nvfp4_scale_t S_dec_b_fp8; if constexpr (USE_4OVER6) { + nvfp4_scale_t S_dec_b_fp8; float block_S_enc_rowwise; float block_global_amax; if constexpr (ROW_SCALED_NVFP4) { @@ -362,6 +362,13 @@ __device__ __forceinline__ void rowwise_scaling( rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } + #pragma unroll for (int w = 0; w < WAVES; ++w) { const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; @@ -369,6 +376,7 @@ __device__ __forceinline__ void rowwise_scaling( ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], rOut[w]); } } else { + nvfp4_scale_t S_dec_b_fp8; scaling_coeff_type SFcoefficient; if constexpr (ROW_SCALED_NVFP4) { const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; @@ -385,6 +393,13 @@ __device__ __forceinline__ void rowwise_scaling( compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); } + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } + // Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -407,13 +422,6 @@ __device__ __forceinline__ void rowwise_scaling( ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); } } - - // Store scaling factors to SMEM buffer (R2S) - if (SF_storing_thread) { - const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = stage_rowwise_scales_offset_X; - sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; - } } } From b1d073a7df252e4813f810ba9739b2dfc42ee7b3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 12:22:47 -0700 Subject: [PATCH 45/57] Add MAE and dedicated fast math env var Signed-off-by: Ziang Li --- docs/envvars.rst | 14 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 110 ++++--- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 21 ++ .../nvfp4/test_nvfp4_quantize_exact.py | 18 ++ tests/pytorch/test_recipe.py | 7 +- .../common/cast/dispatch/quantize.cuh | 9 +- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 288 +++++++++--------- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 79 ++--- .../quantize_transpose_nvfp4_tuned_1D.cuh | 103 ++++--- transformer_engine/common/common.h | 8 +- .../transformer_engine/transformer_engine.h | 42 ++- transformer_engine/common/recipe/__init__.py | 10 +- .../common/transformer_engine.cpp | 19 ++ .../common/transpose/cast_transpose.h | 5 +- ...quantize_transpose_vector_blockwise_fp4.cu | 112 ++++--- transformer_engine/pytorch/csrc/common.h | 1 + .../pytorch/csrc/extensions/cast.cpp | 36 ++- transformer_engine/pytorch/csrc/quantizer.cpp | 22 +- .../custom_recipes/quantization_ref_nvfp4.py | 34 ++- transformer_engine/pytorch/quantization.py | 1 + .../pytorch/tensor/nvfp4_tensor.py | 6 + 21 files changed, 620 insertions(+), 325 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index 5ad271837c..d2456f3364 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -291,7 +291,19 @@ Kernel Configuration :Type: ``str`` (``weights``, ``activations``, or ``all``) :Default: unset - :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower input-domain MSE, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled; activation and backward scopes therefore require ``NVTE_NVFP4_DISABLE_RHT=1`` and ``NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1``. + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower configured input-domain error, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled; activation and backward scopes therefore require ``NVTE_NVFP4_DISABLE_RHT=1`` and ``NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1``. + +.. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE + + :Type: ``str`` (``MAE`` or ``MSE``) + :Default: ``MAE`` + :Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. + +.. envvar:: NVTE_NVFP4_4OVER6_ERR_FAST_MATH + + :Type: ``int`` (0 or 1) + :Default: ``0`` + :Description: Allow the NVFP4 4over6 candidate error computation to use faster non-strict floating-point expressions. By default, 4over6 error comparison uses strict expressions; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path. Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 2e88f2dc1a..7a10c50275 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -71,7 +71,7 @@ float compute_global_encode_scaling_factor_FP4(const float global_amax, const bo constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return the max normalized value - const float max_norm_clamp = use_fast_math + const float max_norm_clamp = (use_fast_math && !use_4over6) ? Numeric_Traits::maxNorm : Numeric_Traits::maxNorm; @@ -95,7 +95,7 @@ struct NVFP4FourOverSixQuantization { }; NVFP4FourOverSixQuantization compute_4over6_quantization_scales( - const float block_amax, const float global_encode_scale, const bool use_fast_math) { + const float block_amax, const float global_encode_scale) { constexpr float fp4_max = 6.0f; constexpr float scale_expansion_factor = 1.5f; const float base_sf_high_precision = block_amax / fp4_max * global_encode_scale; @@ -118,11 +118,6 @@ NVFP4FourOverSixQuantization compute_4over6_quantization_scales( Numeric_Traits::maxNorm); } - if (use_fast_math) { - reciprocal_map4 = static_cast(static_cast(reciprocal_map4)); - reciprocal_map6 = static_cast(static_cast(reciprocal_map6)); - } - const float2 zero = {0.0f, 0.0f}; return { scale_map4, @@ -139,19 +134,24 @@ NVFP4FourOverSixQuantization compute_4over6_quantization_scales( float compute_4over6_dequantized_value(const double quantized_value, const fp8e4m3 scale, const float global_amax) { - constexpr float mse_scale = 1.0f / (6.0f * 256.0f); + constexpr float error_scale = 1.0f / (6.0f * 256.0f); return static_cast(quantized_value) * static_cast(scale) * global_amax * - mse_scale; + error_scale; } -float compute_squared_error(const float value, const float reference) { +float compute_4over6_error(const float value, const float reference, + const NVTENVFP44Over6ErrMode err_mode) { const float diff = value - reference; - return diff * diff; + if (err_mode == kNVTENVFP44Over6ErrMSE) { + return diff * diff; + } + NVTE_CHECK(err_mode == kNVTENVFP44Over6ErrMAE, "Unsupported NVFP4 4over6 error mode."); + return fabsf(diff); } NVFP4FourOverSixQuantization quantize_4over6_pair( const float x, const float y, const NVFP4FourOverSixQuantization& quantization, - const float global_amax) { + const float global_amax, const NVTENVFP44Over6ErrMode err_mode) { const float2 scaled_map4 = {x * quantization.reciprocal_map4, y * quantization.reciprocal_map4}; const fp4e2m1x2 quantized_map4(scaled_map4); @@ -177,8 +177,10 @@ NVFP4FourOverSixQuantization quantize_4over6_pair( quantization.reciprocal_map6, quantized_map4, quantized_map6, - compute_squared_error(dequant_x_map4, x) + compute_squared_error(dequant_y_map4, y), - compute_squared_error(dequant_x_map6, x) + compute_squared_error(dequant_y_map6, y), + compute_4over6_error(dequant_x_map4, x, err_mode) + + compute_4over6_error(dequant_y_map4, y, err_mode), + compute_4over6_error(dequant_x_map6, x, err_mode) + + compute_4over6_error(dequant_y_map6, y, err_mode), }; } @@ -193,7 +195,8 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t scales_stride, const float global_amax, const bool use_fast_math, - const bool use_4over6 = false) { + const bool use_4over6 = false, + const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { // Compute a global encoding/decoding scaling factor for all S_dec_b const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); @@ -230,7 +233,7 @@ void quantize_nvfp4_1d(float (*OP)(const float), if (use_4over6) { const NVFP4FourOverSixQuantization quantization = - compute_4over6_quantization_scales(block_amax, S_enc, use_fast_math); + compute_4over6_quantization_scales(block_amax, S_enc); std::array output_map6; std::array output_map4; @@ -243,7 +246,8 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float cached_x = cache_buffer[cache_idx_x]; const float cached_y = cache_buffer[cache_idx_y]; const NVFP4FourOverSixQuantization pair_quantization = - quantize_4over6_pair(cached_x, cached_y, quantization, global_amax); + quantize_4over6_pair(cached_x, cached_y, quantization, global_amax, + err_mode); output_map4[cache_idx_x / 2] = pair_quantization.quantized_map4; output_map6[cache_idx_x / 2] = pair_quantization.quantized_map6; err_map4 += pair_quantization.error_map4; @@ -312,7 +316,9 @@ void compute_2d_mathematical_scales(float (*OP)(const float), const float global_amax, std::vector>& math_scales, const bool use_fast_math, - const bool use_4over6 = false) { + const bool use_4over6 = false, + const NVTENVFP44Over6ErrMode err_mode = + kNVTENVFP44Over6ErrMAE) { const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); constexpr size_t block_size_Y = 16; @@ -344,7 +350,7 @@ void compute_2d_mathematical_scales(float (*OP)(const float), // Compute E4M3 scaling factor for this 16x16 block if (use_4over6) { const NVFP4FourOverSixQuantization quantization = - compute_4over6_quantization_scales(block_amax, S_enc, use_fast_math); + compute_4over6_quantization_scales(block_amax, S_enc); float err_map6 = 0.0f; float err_map4 = 0.0f; @@ -360,7 +366,8 @@ void compute_2d_mathematical_scales(float (*OP)(const float), cached_y = static_cast(static_cast(act_y)); } const NVFP4FourOverSixQuantization pair_quantization = - quantize_4over6_pair(cached_x, cached_y, quantization, global_amax); + quantize_4over6_pair(cached_x, cached_y, quantization, global_amax, + err_mode); err_map4 += pair_quantization.error_map4; err_map6 += pair_quantization.error_map6; } @@ -391,12 +398,13 @@ void quantize_nvfp4_2d(float (*OP)(const float), const size_t scales_stride, const float global_amax, const bool use_fast_math, - const bool use_4over6 = false) { + const bool use_4over6 = false, + const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math, - use_4over6); + use_4over6, err_mode); const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); constexpr size_t block_size_Y = 16; @@ -481,13 +489,14 @@ void quantize_nvfp4(float (*OP)(const float), const float global_amax, const bool use_fast_math, const bool use_2d_quantization = false, - const bool use_4over6 = false) { + const bool use_4over6 = false, + const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { if (use_2d_quantization) { quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6); + use_fast_math, use_4over6, err_mode); } else { quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6); + use_fast_math, use_4over6, err_mode); } } @@ -506,7 +515,8 @@ void compute_ref(float (*OP)(const float), const bool use_fast_math, const bool use_2d_quantization = false, const bool row_scaled_nvfp4 = false, - const bool use_4over6 = false) + const bool use_4over6 = false, + const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { std::vector input_t = create_transpose(input, rows, cols); NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4), @@ -517,7 +527,7 @@ void compute_ref(float (*OP)(const float), // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math, - use_4over6); + use_4over6, err_mode); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -545,9 +555,9 @@ void compute_ref(float (*OP)(const float), // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, *amax, - use_fast_math, use_4over6); // scales already filled + use_fast_math, use_4over6, err_mode); // scales already filled quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, *amax, - use_fast_math, use_4over6); // scales_t already filled + use_fast_math, use_4over6, err_mode); // scales_t already filled return; } @@ -565,16 +575,17 @@ void compute_ref(float (*OP)(const float), amax[row], use_fast_math, use_2d_quantization, - use_4over6); + use_4over6, + err_mode); } return; } // Ref impl for basic NVFP4 quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, *amax, - use_fast_math, use_2d_quantization, use_4over6); + use_fast_math, use_2d_quantization, use_4over6, err_mode); quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, *amax, - use_fast_math, use_2d_quantization, use_4over6); + use_fast_math, use_2d_quantization, use_4over6, err_mode); } void compare_nvfp4_tensors(const std::string& name, @@ -715,7 +726,8 @@ void performTest(float (*OP)(const float), const bool use_fast_math, const bool use_2d_quantization = false, const bool row_scaled_nvfp4 = false, - const bool use_4over6 = false) { + const bool use_4over6 = false, + const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { using namespace test; NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4), @@ -807,7 +819,8 @@ void performTest(float (*OP)(const float), use_fast_math, use_2d_quantization, row_scaled_nvfp4, - use_4over6); + use_4over6, + err_mode); // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); @@ -817,11 +830,12 @@ void performTest(float (*OP)(const float), // Quantization options QuantizationConfigWrapper quant_config; - quant_config.set_use_fast_math(use_fast_math); + quant_config.set_use_fast_math(use_fast_math && !use_4over6); quant_config.set_stochastic_rounding(false); quant_config.set_rng_state(rng_state.data()); quant_config.set_nvfp4_2d_quantization(use_2d_quantization); quant_config.set_nvfp4_4over6(use_4over6); + quant_config.set_nvfp4_4over6_err_mode(err_mode); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -903,7 +917,8 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam bool, bool, bool, - bool>> {}; + bool, + NVTENVFP44Over6ErrMode>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { // Skip tests for pre-Blackwell architectures @@ -921,6 +936,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const bool use_2d_quantization = std::get<4>(GetParam()); const bool row_scaled_nvfp4 = std::get<5>(GetParam()); const bool use_4over6 = std::get<6>(GetParam()); + const NVTENVFP44Over6ErrMode err_mode = std::get<7>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -939,7 +955,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, performTest(OP, tensor_dims, use_fast_math, use_2d_quantization, - row_scaled_nvfp4, use_4over6); + row_scaled_nvfp4, use_4over6, err_mode); ); } @@ -973,6 +989,11 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) } if (std::get<6>(param)) { name += "X4OVER6"; + if (std::get<7>(param) == kNVTENVFP44Over6ErrMSE) { + name += "XMSE"; + } else { + name += "XMAE"; + } } return name; } @@ -987,7 +1008,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(false), ::testing::Values(false), - ::testing::Values(false)), + ::testing::Values(false), + ::testing::Values(kNVTENVFP44Over6ErrMAE)), [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1002,7 +1024,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(false), ::testing::Values(true), - ::testing::Values(false)), + ::testing::Values(false), + ::testing::Values(kNVTENVFP44Over6ErrMAE)), [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1017,7 +1040,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(false), ::testing::Values(false), - ::testing::Values(true)), + ::testing::Values(true), + ::testing::Values(kNVTENVFP44Over6ErrMAE, kNVTENVFP44Over6ErrMSE)), [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1032,7 +1056,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(false), ::testing::Values(true), - ::testing::Values(true)), + ::testing::Values(true), + ::testing::Values(kNVTENVFP44Over6ErrMAE, kNVTENVFP44Over6ErrMSE)), [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1047,7 +1072,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), ::testing::Values(true), ::testing::Values(false), - ::testing::Values(true)), + ::testing::Values(true), + ::testing::Values(kNVTENVFP44Over6ErrMAE, kNVTENVFP44Over6ErrMSE)), [](const testing::TestParamInfo& info) { return test_name(info.param); }); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 3450c4bd0b..10d578aa95 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -29,6 +29,7 @@ def check_nvfp4_gemm_versus_reference( w_columnwise: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_err_mode: str = "MAE", ): te_dtype = tex.DType.kFloat4E2M1 @@ -61,6 +62,7 @@ def check_nvfp4_gemm_versus_reference( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -71,6 +73,7 @@ def check_nvfp4_gemm_versus_reference( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) # Quantize x and w @@ -127,6 +130,7 @@ def check_nvfp4_gemm_versus_reference( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -136,6 +140,7 @@ def check_nvfp4_gemm_versus_reference( eps=0.0, quant_tile_shape=(1, 16), use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) # Create reference quantized tensors needed by reference GEMM @@ -238,6 +243,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, use_4over6: bool = False, + four_over_six_err_mode: str = "MAE", ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -256,6 +262,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( with_post_rht_amax=False, row_scaled_nvfp4=True, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -266,6 +273,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) x_nvfp4 = [] @@ -330,6 +338,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( K: int, N: int, use_4over6: bool = False, + four_over_six_err_mode: str = "MAE", ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -349,6 +358,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_post_rht_amax=False, row_scaled_nvfp4=True, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) x_tensorwise_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -359,6 +369,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -369,6 +380,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) x_row_scaled = x_row_scaled_quantizer.update_quantized( @@ -430,6 +442,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -442,6 +455,7 @@ def test_nvfp4_gemm_versus_reference( is_w_columnwise: bool, row_scaled_nvfp4: bool, use_4over6: bool, + four_over_six_err_mode: str, ): if row_scaled_nvfp4: if accumulate: @@ -461,6 +475,7 @@ def test_nvfp4_gemm_versus_reference( w_columnwise=is_w_columnwise, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) @@ -487,6 +502,7 @@ def test_nvfp4_gemm_versus_reference( @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( m_splits: list[int], k: int, @@ -497,6 +513,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, use_4over6: bool, + four_over_six_err_mode: str, ): check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( x_dtype=x_dtype, @@ -508,6 +525,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( use_bias=use_bias, single_output=single_output, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) @@ -532,6 +550,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( @pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_row_scaled_gemm_matches_emulated( M: int, K: int, @@ -540,6 +559,7 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( w_dtype: torch.dtype, out_dtype: torch.dtype, use_4over6: bool, + four_over_six_err_mode: str, ): check_nvfp4_row_scaled_gemm_matches_emulated( x_dtype=x_dtype, @@ -549,4 +569,5 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( K=K, N=N, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 60c1543407..b18521f1b1 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -53,6 +53,7 @@ def check_quantization_nvfp4_versus_reference( with_2d_quantization: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_err_mode: str = "MAE", ) -> None: maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6, x_dtype, M, N @@ -80,6 +81,7 @@ def check_quantization_nvfp4_versus_reference( with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -114,6 +116,7 @@ def check_quantization_nvfp4_versus_reference( quant_tile_shape=quant_tile_shape, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -190,6 +193,7 @@ def check_quantization_nvfp4_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -200,6 +204,7 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization: bool, row_scaled_nvfp4: bool, use_4over6: bool, + four_over_six_err_mode: str, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -211,6 +216,7 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) @@ -229,6 +235,7 @@ def test_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -238,6 +245,7 @@ def test_nvfp4_quantization_extrema_versus_reference( use_cpp_allocator: bool, row_scaled_nvfp4: bool, use_4over6: bool, + four_over_six_err_mode: str, ): maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 @@ -265,6 +273,7 @@ def test_nvfp4_quantization_extrema_versus_reference( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) if use_cpp_allocator: @@ -297,6 +306,7 @@ def test_nvfp4_quantization_extrema_versus_reference( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -342,6 +352,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -350,6 +361,7 @@ def test_nvfp4_quantization_boundary_values( use_cpp_allocator: bool, row_scaled_nvfp4: bool, use_4over6: bool, + four_over_six_err_mode: str, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -391,6 +403,7 @@ def test_nvfp4_quantization_boundary_values( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) if use_cpp_allocator: @@ -423,6 +436,7 @@ def test_nvfp4_quantization_boundary_values( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -468,6 +482,7 @@ def test_nvfp4_quantization_boundary_values( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, @@ -476,6 +491,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( use_cpp_allocator: bool, row_scaled_nvfp4: bool, use_4over6: bool, + four_over_six_err_mode: str, ): maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 @@ -503,6 +519,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) if use_cpp_allocator: @@ -535,6 +552,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_err_mode=four_over_six_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 4e4a4fdc0c..1c1b79e7aa 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -520,12 +520,14 @@ def test_quantizer_update(self, module_class): [None, "weights", "activations", "all"], ids=["default", "weights", "activations", "all"], ) -def test_nvfp4_row_scaled_quantizer_roles(nvfp4_4over6): +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +def test_nvfp4_row_scaled_quantizer_roles(nvfp4_4over6, nvfp4_4over6_err_mode): recipe = NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, nvfp4_4over6=nvfp4_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, row_scaled_activation=True, ) @@ -547,6 +549,7 @@ def expected_use_4over6(tensor_type): assert [q.use_4over6 for q in forward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("input", "weight", "output") ] + assert [q.four_over_six_err_mode for q in forward_quantizers] == [nvfp4_4over6_err_mode] * 3 assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) @@ -565,6 +568,7 @@ def expected_use_4over6(tensor_type): assert [q.use_4over6 for q in role_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("weight", "input", "output", "input") ] + assert [q.four_over_six_err_mode for q in role_quantizers] == [nvfp4_4over6_err_mode] * 4 backward_quantizers = NVFP4BlockScalingRecipeState( recipe, @@ -579,6 +583,7 @@ def expected_use_4over6(tensor_type): assert [q.use_4over6 for q in backward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("grad_output", "grad_input") ] + assert [q.four_over_six_err_mode for q in backward_quantizers] == [nvfp4_4over6_err_mode] * 2 @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index a28aac98ca..3cc4ed93e5 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -137,11 +137,12 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, /*swizzled_scale=*/false, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*use_fast_math=*/quant_config_cpp.use_fast_math, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/row_scaled_nvfp4, /*use_4over6=*/use_4over6, + /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, + /*nvfp4_4over6_err_fast_math=*/quant_config_cpp.nvfp4_4over6_err_fast_math, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } @@ -287,11 +288,13 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, /*swizzled_scale=*/false, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*use_fast_math=*/quant_config_cpp.use_fast_math, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/false, - /*use_4over6=*/use_4over6, /*noop_tensor=*/noop_tensor->data, + /*use_4over6=*/use_4over6, + /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, + /*nvfp4_4over6_err_fast_math=*/quant_config_cpp.nvfp4_4over6_err_fast_math, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } break; diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index c084459e11..d0fdc30369 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -10,8 +10,8 @@ * 4over6 evaluates two TE-style NVFP4 encodings for each 1x16 block. The * map-to-6 candidate uses the normal block scale. The map-to-4 candidate uses * a 1.5x expanded block scale, which maps the FP4 value 4 to the same dynamic - * range as FP4 value 6. The selected candidate is the one with lower MSE after - * dequantizing back to the original input domain; ties select map-to-6. + * range as FP4 value 6. The selected candidate is the one with lower configured + * error after dequantizing back to the original input domain; ties select map-to-6. */ #ifndef TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ @@ -22,8 +22,6 @@ #include #include -#include - #include "core_nvfp4.cuh" namespace transformer_engine { @@ -33,6 +31,21 @@ namespace core { #if FP4_TYPE_SUPPORTED +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH(ERR_MODE, ERR_MODE_CONST, ...) \ + switch (ERR_MODE) { \ + case kNVTENVFP44Over6ErrMAE: { \ + constexpr NVTENVFP44Over6ErrMode ERR_MODE_CONST = kNVTENVFP44Over6ErrMAE; \ + { __VA_ARGS__ } \ + } break; \ + case kNVTENVFP44Over6ErrMSE: { \ + constexpr NVTENVFP44Over6ErrMode ERR_MODE_CONST = kNVTENVFP44Over6ErrMSE; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported NVFP4 4over6 error mode."); \ + } \ + } + __device__ __forceinline__ void compute_4over6_decoding_scaling_factors( const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, nvfp4_scale_t &S_dec_b_fp8_map6) { @@ -45,50 +58,36 @@ __device__ __forceinline__ void compute_4over6_decoding_scaling_factors( S_dec_b_fp8_map6 = static_cast(sf_high_precision_map6); } -template struct QuantizationScales4Over6 { nvfp4_scale_t S_dec_b_fp8_map4; nvfp4_scale_t S_dec_b_fp8_map6; - scaling_coeff_type SFcoefficient_map4; - scaling_coeff_type SFcoefficient_map6; + float SFcoefficient_map4; + float SFcoefficient_map6; }; -template -__device__ __forceinline__ scaling_coeff_type -compute_4over6_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { - if constexpr (std::is_same_v) { - const float S_dec = 1.0f / S_enc; - const float scale_rcp = - fminf(1.0f / (static_cast(S_dec_block) * S_dec), detail::TypeExtrema::max); - return scale_rcp; - } else if constexpr (std::is_same_v) { - const float scale_rcp = - fminf(S_enc / static_cast(S_dec_block), detail::TypeExtrema::max); - return static_cast(scale_rcp); - } else { - NVTE_DEVICE_ERROR("Unsupported scaling-factor type. Only FP32 and BF16 are supported."); - return scaling_coeff_type{}; - } +__device__ __forceinline__ float compute_4over6_nvfp4_scaling_coefficient( + const nvfp4_scale_t S_dec_block, const float S_enc) { + const float S_dec = 1.0f / S_enc; + return fminf(1.0f / (static_cast(S_dec_block) * S_dec), detail::TypeExtrema::max); } -template -__device__ __forceinline__ QuantizationScales4Over6 +__device__ __forceinline__ QuantizationScales4Over6 compute_4over6_nvfp4_quantization_scaling_factors(const float block_amax, const float S_enc) { - QuantizationScales4Over6 scaling_factors; + QuantizationScales4Over6 scaling_factors; compute_4over6_decoding_scaling_factors(block_amax, S_enc, scaling_factors.S_dec_b_fp8_map4, scaling_factors.S_dec_b_fp8_map6); - scaling_factors.SFcoefficient_map4 = compute_4over6_nvfp4_scaling_coefficient( - scaling_factors.S_dec_b_fp8_map4, S_enc); - scaling_factors.SFcoefficient_map6 = compute_4over6_nvfp4_scaling_coefficient( - scaling_factors.S_dec_b_fp8_map6, S_enc); + scaling_factors.SFcoefficient_map4 = + compute_4over6_nvfp4_scaling_coefficient(scaling_factors.S_dec_b_fp8_map4, S_enc); + scaling_factors.SFcoefficient_map6 = + compute_4over6_nvfp4_scaling_coefficient(scaling_factors.S_dec_b_fp8_map6, S_enc); return scaling_factors; } -__device__ __forceinline__ QuantizationScales4Over6 +__device__ __forceinline__ QuantizationScales4Over6 compute_4over6_fp4_encode_quantization_scaling_factors(const float block_amax, const float global_encode_scale, const float global_decode_scale) { - QuantizationScales4Over6 scaling_factors; + QuantizationScales4Over6 scaling_factors; compute_4over6_decoding_scaling_factors(block_amax, global_encode_scale, scaling_factors.S_dec_b_fp8_map4, scaling_factors.S_dec_b_fp8_map6); @@ -101,12 +100,34 @@ compute_4over6_fp4_encode_quantization_scaling_factors(const float block_amax, return scaling_factors; } -template -__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float (&x)[8], - const float block_scale_inverse, - const nvfp4_scale_t S_dec_b_fp8, - const float global_amax, - float *err) { +template +__device__ __forceinline__ float compute_4over6_error_rn(const float diff) { + if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMSE) { + return __fmul_rn(diff, diff); + } else if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMAE) { + return fabsf(diff); + } else { + NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 error mode."); + return fabsf(diff); + } +} + +template +__device__ __forceinline__ float compute_4over6_error(const float diff) { + if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMSE) { + return diff * diff; + } else if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMAE) { + return fabsf(diff); + } else { + NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 error mode."); + return fabsf(diff); + } +} + +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( + const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t S_dec_b_fp8, + const float global_amax, float *err) { uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -116,20 +137,9 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float ( constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { float x_scaled[8]; - if constexpr (USE_FAST_MATH) { #pragma unroll - for (int i = 0; i < 8; ++i) { - x_scaled[i] = x[i] * block_scale_inverse; - } - } else { - x_scaled[0] = __fmul_rn(x[0], block_scale_inverse); - x_scaled[1] = __fmul_rn(x[1], block_scale_inverse); - x_scaled[2] = __fmul_rn(x[2], block_scale_inverse); - x_scaled[3] = __fmul_rn(x[3], block_scale_inverse); - x_scaled[4] = __fmul_rn(x[4], block_scale_inverse); - x_scaled[5] = __fmul_rn(x[5], block_scale_inverse); - x_scaled[6] = __fmul_rn(x[6], block_scale_inverse); - x_scaled[7] = __fmul_rn(x[7], block_scale_inverse); + for (int i = 0; i < 8; ++i) { + x_scaled[i] = __fmul_rn(x[i], block_scale_inverse); } asm volatile( @@ -161,9 +171,9 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float ( constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f constexpr float fp8_4over6_max = 256.0f; - constexpr float mse_denom = fp4_max * fp8_4over6_max; + constexpr float err_denom = fp4_max * fp8_4over6_max; const float sf = static_cast(S_dec_b_fp8); - if constexpr (USE_FAST_MATH) { + if constexpr (USE_ERR_FAST_MATH) { const float dequant[8] = { __half2float(__ushort_as_half(out_dequant_1_lo)), __half2float(__ushort_as_half(out_dequant_1_hi)), @@ -176,35 +186,35 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float ( }; #pragma unroll for (int i = 0; i < 8; ++i) { - const float val = dequant[i] * sf * global_amax / mse_denom; + const float val = dequant[i] * sf * global_amax / err_denom; const float diff = val - x[i]; - *err += diff * diff; + *err += compute_4over6_error(diff); } } else { const float val0 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_lo)), sf), global_amax), - mse_denom); + err_denom); const float val1 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_hi)), sf), global_amax), - mse_denom); + err_denom); const float val2 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_lo)), sf), global_amax), - mse_denom); + err_denom); const float val3 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_hi)), sf), global_amax), - mse_denom); + err_denom); const float val4 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_lo)), sf), global_amax), - mse_denom); + err_denom); const float val5 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_hi)), sf), global_amax), - mse_denom); + err_denom); const float val6 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_lo)), sf), global_amax), - mse_denom); + err_denom); const float val7 = __fdiv_rn( __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_hi)), sf), global_amax), - mse_denom); + err_denom); const float diff0 = __fsub_rn(val0, x[0]); const float diff1 = __fsub_rn(val1, x[1]); @@ -215,14 +225,14 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float ( const float diff6 = __fsub_rn(val6, x[6]); const float diff7 = __fsub_rn(val7, x[7]); - *err = __fadd_rn(*err, __fmul_rn(diff0, diff0)); - *err = __fadd_rn(*err, __fmul_rn(diff1, diff1)); - *err = __fadd_rn(*err, __fmul_rn(diff2, diff2)); - *err = __fadd_rn(*err, __fmul_rn(diff3, diff3)); - *err = __fadd_rn(*err, __fmul_rn(diff4, diff4)); - *err = __fadd_rn(*err, __fmul_rn(diff5, diff5)); - *err = __fadd_rn(*err, __fmul_rn(diff6, diff6)); - *err = __fadd_rn(*err, __fmul_rn(diff7, diff7)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff0)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff1)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff2)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff3)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff4)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff5)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff6)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff7)); } } else { NVTE_DEVICE_ERROR( @@ -233,35 +243,37 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_mse_rn(const float ( return out; } -template -__device__ __forceinline__ void quantize_4over6_16x( - const float (&first_half)[8], const float (&second_half)[8], - const QuantizationScales4Over6 &scaling_factors, const float global_amax, - float &err_map4, float &err_map6, uint32_t (&rOut_map4)[2], uint32_t (&rOut_map6)[2]) { +template +__device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], + const float (&second_half)[8], + const QuantizationScales4Over6 &scaling_factors, + const float global_amax, float &err_map4, + float &err_map6, uint32_t (&rOut_map4)[2], + uint32_t (&rOut_map6)[2]) { if constexpr (REVERSE_PACK_ORDER) { - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } else { - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_mse_rn( + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } @@ -271,28 +283,29 @@ __device__ __forceinline__ bool pick_4over6_map4(const float err_map4, const flo return err_map4 < err_map6; } -template -__device__ __forceinline__ nvfp4_scale_t selected_4over6_scale( - const bool pick_map4, const QuantizationScales4Over6 &scaling_factors) { +__device__ __forceinline__ nvfp4_scale_t +selected_4over6_scale(const bool pick_map4, const QuantizationScales4Over6 &scaling_factors) { if (pick_map4) { return scaling_factors.S_dec_b_fp8_map4; } return scaling_factors.S_dec_b_fp8_map6; } -template -__device__ __forceinline__ void quantize_4over6_16x( - const float (&first_half)[8], const float (&second_half)[8], - const QuantizationScales4Over6 &scaling_factors, const float global_amax, - nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { +template +__device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], + const float (&second_half)[8], + const QuantizationScales4Over6 &scaling_factors, + const float global_amax, + nvfp4_scale_t &S_dec_b_fp8, + uint32_t (&rOut)[2]) { float err_map4 = 0.0f; float err_map6 = 0.0f; __align__(8) uint32_t rOut_map4[2]; __align__(8) uint32_t rOut_map6[2]; - quantize_4over6_16x(first_half, second_half, scaling_factors, - global_amax, err_map4, err_map6, rOut_map4, - rOut_map6); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, err_map4, err_map6, rOut_map4, + rOut_map6); const bool pick_map4 = pick_4over6_map4(err_map4, err_map6); S_dec_b_fp8 = selected_4over6_scale(pick_map4, scaling_factors); @@ -390,25 +403,24 @@ __device__ __forceinline__ void load_4over6_vec_index_halves_16x(const vec_type } } -template +template __device__ __forceinline__ void quantize_4over6_candidates_16x( - const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, - const float global_amax, QuantizationCandidates4Over6 &candidates) { + const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, const float global_amax, + QuantizationCandidates4Over6 &candidates) { float first_half[8]; float second_half[8]; load_4over6_contiguous_halves_16x(x, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, - candidates.err_map4, candidates.err_map6, candidates.rOut_map4, - candidates.rOut_map6); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, candidates.err_map4, + candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } -template +template __device__ __forceinline__ bool record_and_select_4over6_2d_block( - const QuantizationScales4Over6 &scaling_factors, - const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, + const QuantizationScales4Over6 &scaling_factors, const size_t block_in_tile_y, + const size_t block_in_tile_x, const size_t participant_idx, QuantizationScratch4Over6 &scratch, nvfp4_scale_t &S_dec_b_fp8, const QuantizationCandidates4Over6 &candidates) { scratch.err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; @@ -438,7 +450,8 @@ __device__ __forceinline__ bool record_and_select_4over6_2d_block( return scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; } -template +template __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( const float (&x)[16], const float block_amax, const float global_encode_scale, const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, @@ -447,7 +460,8 @@ __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( nvfp4_scale_t &S_dec_b_fp8, QuantizationCandidates4Over6 &candidates) { const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( block_amax, global_encode_scale, global_decode_scale); - quantize_4over6_candidates_16x(x, scaling_factors, global_amax, candidates); + quantize_4over6_candidates_16x(x, scaling_factors, global_amax, + candidates); const bool pick_map4 = record_and_select_4over6_2d_block( @@ -510,86 +524,84 @@ __device__ __forceinline__ void store_selected_4over6_packed_16x( store_4over6_packed_16x(candidates.selected_packed(pick_map4), output_vec); } -template __device__ __forceinline__ void quantize_4over6_contiguous_16x( - const input_type *x, const QuantizationScales4Over6 &scaling_factors, - const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + const input_type *x, const QuantizationScales4Over6 &scaling_factors, const float global_amax, + nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; load_4over6_contiguous_halves_16x(x, first_half, second_half); - quantize_4over6_16x(first_half, second_half, scaling_factors, - global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template __device__ __forceinline__ void quantize_4over6_pair_array_16x( - const pair_type (&x)[2][4], const QuantizationScales4Over6 &scaling_factors, + const pair_type (&x)[2][4], const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; load_4over6_pair_array_halves_16x(x, first_half, second_half); - quantize_4over6_16x(first_half, second_half, scaling_factors, - global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( - const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, + const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { float first_half[8]; float second_half[8]; load_4over6_vec2_array_halves_16x(x, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, - candidates.err_map4, candidates.err_map6, candidates.rOut_map4, - candidates.rOut_map6); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, candidates.err_map4, + candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } -template __device__ __forceinline__ void quantize_4over6_vec2_array_16x( - const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, + const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; load_4over6_vec2_array_halves_16x(x, first_half, second_half); - quantize_4over6_16x(first_half, second_half, scaling_factors, - global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( - const vec_type (&x)[16], const int idx, - const QuantizationScales4Over6 &scaling_factors, const float global_amax, - QuantizationCandidates4Over6 &candidates) { + const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, + const float global_amax, QuantizationCandidates4Over6 &candidates) { float first_half[8]; float second_half[8]; load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x(first_half, second_half, scaling_factors, global_amax, - candidates.err_map4, candidates.err_map6, candidates.rOut_map4, - candidates.rOut_map6); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, candidates.err_map4, + candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } -template __device__ __forceinline__ void quantize_4over6_vec_index_16x( - const vec_type (&x)[16], const int idx, - const QuantizationScales4Over6 &scaling_factors, const float global_amax, - nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { + const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, + const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { float first_half[8]; float second_half[8]; load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); - quantize_4over6_16x(first_half, second_half, scaling_factors, - global_amax, S_dec_b_fp8, rOut); + quantize_4over6_16x( + first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 046a60882c..2774930fa6 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -784,8 +784,8 @@ __global__ void __launch_bounds__(THREADS_NUM) } template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE, bool USE_4OVER6, + NVTENVFP44Over6ErrMode USE_4OVER6_ERR_MODE, bool USE_4OVER6_ERR_FAST_MATH> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -1113,7 +1113,8 @@ __global__ void __launch_bounds__(THREADS_NUM) QuantizationCandidates4Over6 candidates; nvfp4_scale_t S_dec_b_fp8; const bool pick_map4 = - quantize_and_select_4over6_2d_block_16x( x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, @@ -1276,7 +1277,8 @@ __global__ void __launch_bounds__(THREADS_NUM) QuantizationCandidates4Over6 candidates; nvfp4_scale_t S_dec_b_fp8; const bool pick_map4 = - quantize_and_select_4over6_2d_block_16x( in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, @@ -1419,8 +1421,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; - const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; + const NVTENVFP44Over6ErrMode use_4over6_err_mode = + use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; + const bool use_4over6_err_fast_math = + use_4over6 && quant_config && quant_config->nvfp4_4over6_err_fast_math; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -1537,35 +1542,41 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_fast_math, USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { - TRANSFORMER_ENGINE_SWITCH_CONDITION(use_4over6, USE_4OVER6, { - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; - - if constexpr (use_2d_quantization) { - kernel = - quantize_transpose_nvfp4_2D_kernel; - } - using FourOverSixScratch = - core::QuantizationScratch4Over6; - constexpr size_t dshmem_size = - base_dshmem_size + - FourOverSixScratch::template dynamic_shared_memory_size(); - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - }); - }); + row_scaled_nvfp4, ROW_SCALED_NVFP4, + TRANSFORMER_ENGINE_SWITCH_CONDITION(use_4over6, USE_4OVER6, { + TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( + use_4over6_err_mode, USE_4OVER6_ERR_MODE, { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6_err_fast_math, USE_4OVER6_ERR_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel< + COMPUTE_ACTIVATIONS, ParamOP, OP, IType, USE_STOCHASTIC_ROUNDING, + RETURN_TRANSPOSE, USE_4OVER6, USE_4OVER6_ERR_MODE, + USE_4OVER6_ERR_FAST_MATH>; + } + using FourOverSixScratch = + core::QuantizationScratch4Over6; + constexpr size_t dshmem_size = + base_dshmem_size + + FourOverSixScratch::template dynamic_shared_memory_size< + use_2d_quantization, USE_4OVER6>(); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, + scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, + amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, + rng_state); + });); + }); }););); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index e210a1cb8d..dc0c663473 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -185,7 +185,8 @@ compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const f return static_cast(scale_rcp); } -template +template __device__ __forceinline__ void colwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, @@ -234,11 +235,10 @@ __device__ __forceinline__ void colwise_scaling( __align__(8) uint32_t rOut[SCALE_DIM / 8]; nvfp4_scale_t S_dec_b_fp8; const auto scaling_factors = - core::compute_4over6_nvfp4_quantization_scaling_factors( - block_amax[w], S_enc_colwise); + core::compute_4over6_nvfp4_quantization_scaling_factors(block_amax[w], S_enc_colwise); - core::quantize_4over6_contiguous_16x(rIn[w], scaling_factors, - global_amax_colwise, S_dec_b_fp8, rOut); + core::quantize_4over6_contiguous_16x( + rIn[w], scaling_factors, global_amax_colwise, S_dec_b_fp8, rOut); // Store scaling factors to SMEM buffer (R2S) sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; @@ -279,7 +279,8 @@ __device__ __forceinline__ void colwise_scaling( } } -template +template __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, @@ -350,15 +351,14 @@ __device__ __forceinline__ void rowwise_scaling( block_S_enc_rowwise = S_enc_rowwise; } const auto scaling_factors = - core::compute_4over6_nvfp4_quantization_scaling_factors( - block_amax, block_S_enc_rowwise); + core::compute_4over6_nvfp4_quantization_scaling_factors(block_amax, block_S_enc_rowwise); __align__(8) uint32_t rOut[WAVES]; if (bank_group == 0) { - core::quantize_4over6_pair_array_16x(rIn, scaling_factors, block_global_amax, - S_dec_b_fp8, rOut); + core::quantize_4over6_pair_array_16x( + rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } else { - core::quantize_4over6_pair_array_16x( + core::quantize_4over6_pair_array_16x( rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } @@ -426,7 +426,8 @@ __device__ __forceinline__ void rowwise_scaling( } template + bool ROW_SCALED_NVFP4, bool USE_4OVER6, NVTENVFP44Over6ErrMode USE_4OVER6_ERR_MODE, + bool USE_4OVER6_ERR_FAST_MATH> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -652,12 +653,14 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( + rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { - colwise_scaling( + colwise_scaling( sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, global_amax_colwise, stage_Y, stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); } @@ -760,8 +763,11 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, using namespace ptx; const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; - const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; + const NVTENVFP44Over6ErrMode use_4over6_err_mode = + use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; + const bool use_4over6_err_fast_math = + use_4over6 && quant_config && quant_config->nvfp4_4over6_err_fast_math; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; // If transposed output is allocated, return the transposed data @@ -873,28 +879,51 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const int dshmem_size = in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_fast_math, USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6, USE_4OVER6, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = - quantize_transpose_nvfp4_tuned_1D_kernel; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, - scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, - amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, - rng_state); - }););););); + if (use_4over6) { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, ROW_SCALED_NVFP4, + TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( + use_4over6_err_mode, USE_4OVER6_ERR_MODE, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6_err_fast_math, USE_4OVER6_ERR_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< + USE_STOCHASTIC_ROUNDING, + /*USE_FAST_MATH=*/false, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, + /*USE_4OVER6=*/true, USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_FAST_MATH>; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, + scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, + amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, + rng_state); + }););););); + } else { + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, ROW_SCALED_NVFP4, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< + USE_STOCHASTIC_ROUNDING, USE_FAST_MATH, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, + /*USE_4OVER6=*/false, kNVTENVFP44Over6ErrMAE, + /*USE_4OVER6_ERR_FAST_MATH=*/false>; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, + cols, scale_stride, scale_stride_transpose, rng_state); + });););); + } #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index be59dd5068..bafac1518a 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -181,7 +181,7 @@ struct Tensor { /*! \brief Whether NVFP4 uses 4over6 block scale selection. * * Only meaningful for NVFP4 tensors. 4over6 tensors use 256 as their - * global E4M3 scale bound and store the lower-MSE map-to-4/map-to-6 + * global E4M3 scale bound and store a selected map-to-4/map-to-6 * candidate for each 1x16 block. */ bool nvfp4_4over6 = false; @@ -487,6 +487,8 @@ struct QuantizationConfig { bool stochastic_rounding = false; bool use_fast_math = false; bool nvfp4_4over6 = false; + NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; + bool nvfp4_4over6_err_fast_math = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -497,7 +499,9 @@ struct QuantizationConfig { sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding sizeof(uint8_t), // use_fast_math - sizeof(uint8_t) // nvfp4_4over6 + sizeof(uint8_t), // nvfp4_4over6 + sizeof(uint8_t), // nvfp4_4over6_err_mode + sizeof(uint8_t) // nvfp4_4over6_err_fast_math }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 67756d477b..fd13ec958f 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -118,6 +118,14 @@ enum NVTEScalingMode { NVTE_INVALID_SCALING = 100 }; +/*! \enum NVTENVFP44Over6ErrMode + * \brief Candidate-selection error mode for NVFP4 4over6 quantization. + */ +enum NVTENVFP44Over6ErrMode { + kNVTENVFP44Over6ErrMAE = 0, /*!< Select the candidate with lower summed absolute error */ + kNVTENVFP44Over6ErrMSE = 1, /*!< Select the candidate with lower summed squared error */ +}; + /*! \brief TE Tensor type * * NVTETensor is a contiguous tensor type storing a pointer @@ -391,11 +399,25 @@ enum NVTEQuantizationConfigAttribute { /*! Whether to use NVFP4 4over6 block scale selection. * * 4over6 evaluates map-to-4 and map-to-6 candidates for each 1x16 block, - * stores the lower-MSE candidate, and emits tensor data that uses a 256 - * global E4M3 scale bound. The output tensor's kNVTENVFP44Over6 metadata - * must match this option. + * stores the lower-error candidate according to + * kNVTEQuantizationConfigNVFP44Over6ErrMode, and emits tensor data that + * uses a 256 global E4M3 scale bound. The output tensor's + * kNVTENVFP44Over6 metadata must match this option. */ kNVTEQuantizationConfigNVFP44Over6 = 8, + /*! Candidate-selection error mode for NVFP4 4over6 quantization. + * + * The value is an NVTENVFP44Over6ErrMode encoded as uint8_t. It is only + * used when kNVTEQuantizationConfigNVFP44Over6 is enabled. + */ + kNVTEQuantizationConfigNVFP44Over6ErrMode = 9, + /*! Whether the NVFP4 4over6 candidate error computation may use fast math. + * + * This is intentionally separate from kNVTEQuantizationConfigUseFastMath so + * callers can keep candidate selection bitwise deterministic independent + * of ordinary NVFP4 fast-math settings. + */ + kNVTEQuantizationConfigNVFP44Over6ErrFastMath = 10, kNVTEQuantizationConfigNumAttributes }; @@ -1351,6 +1373,20 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set NVFP4 4over6 candidate-selection error mode */ + void set_nvfp4_4over6_err_mode(NVTENVFP44Over6ErrMode mode) { + const auto val = static_cast(mode); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6ErrMode, &val, + sizeof(val)); + } + + /*! \brief Set whether NVFP4 4over6 candidate error computation uses fast math */ + void set_nvfp4_4over6_err_fast_math(bool use_fast_math) { + const auto val = static_cast(use_fast_math); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6ErrFastMath, + &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 82cc42fedd..336ea3c517 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -14,6 +14,7 @@ _BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") _NVFP4_4OVER6_SCOPES = (None, "weights", "activations", "all") +_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE") class _FormatHelper(NamedTuple): @@ -526,13 +527,15 @@ class NVFP4BlockScaling(Recipe): nvfp4_4over6 : {None, 'weights', 'activations', 'all'}, default = None Select tensors that use NVFP4 4over6. In this mode NVFP4 quantization evaluates per-block map-to-4 and map-to-6 candidates - and chooses the one with lower MSE. Ties choose map-to-6. The + and chooses the one with lower configured error. Ties choose map-to-6. The global E4M3 scale bound is 256 in this mode instead of 448. The ``activations`` scope applies to every non-weight tensor role. Random Hadamard transforms and stochastic rounding are not yet supported on tensors that use 4over6; activation and backward scopes therefore require ``disable_rht=True`` and ``disable_stochastic_rounding=True``. + nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE' + Error metric used by NVFP4 4over6 candidate selection. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -548,6 +551,7 @@ class NVFP4BlockScaling(Recipe): disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" nvfp4_4over6: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6", None) + nvfp4_4over6_err_mode: str = os.getenv("NVTE_NVFP4_4OVER6_ERR_MODE", "MAE").upper() fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -566,6 +570,9 @@ def __post_init__(self) -> None: assert ( self.nvfp4_4over6 in _NVFP4_4OVER6_SCOPES ), "NVTE_NVFP4_4OVER6 must be unset or one of: 'weights', 'activations', 'all'." + assert ( + self.nvfp4_4over6_err_mode in _NVFP4_4OVER6_ERR_MODES + ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE'." if self.nvfp4_4over6 in ("activations", "all"): assert self.disable_rht, "NVFP4 4over6 currently requires RHT to be disabled" assert ( @@ -601,6 +608,7 @@ def _make_repr(self) -> str: f"backward_override={self.backward_override}, " f"row_scaled_activation={self.row_scaled_activation}, " f"nvfp4_4over6={self.nvfp4_4over6}, " + f"nvfp4_4over6_err_mode={self.nvfp4_4over6_err_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 2f5cf8ac9c..5eb238dd51 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1058,6 +1058,14 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNVFP44Over6: bool_to_uint8(config_.nvfp4_4over6, buf); break; + case kNVTEQuantizationConfigNVFP44Over6ErrMode: { + const auto val = static_cast(config_.nvfp4_4over6_err_mode); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEQuantizationConfigNVFP44Over6ErrFastMath: + bool_to_uint8(config_.nvfp4_4over6_err_fast_math, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1116,6 +1124,17 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNVFP44Over6: uint8_to_bool(buf, config_.nvfp4_4over6); break; + case kNVTEQuantizationConfigNVFP44Over6ErrMode: { + const auto val = *reinterpret_cast(buf); + NVTE_CHECK(val == static_cast(kNVTENVFP44Over6ErrMAE) || + val == static_cast(kNVTENVFP44Over6ErrMSE), + "Invalid NVFP4 4over6 error mode (got ", static_cast(val), ")"); + config_.nvfp4_4over6_err_mode = static_cast(val); + break; + } + case kNVTEQuantizationConfigNVFP44Over6ErrFastMath: + uint8_to_bool(buf, config_.nvfp4_4over6_err_fast_math); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index e4b839ea6d..c65d0d2d17 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -66,9 +66,10 @@ void quantize_transpose_vector_blockwise_fp4( const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv, SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, - const bool swizzled_scale, const bool use_stochastic_rounding, const bool use_fast_math, + const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const SimpleTensor &noop_tensor, cudaStream_t stream); + const bool use_4over6, const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, + const bool nvfp4_4over6_err_fast_math, const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 9ce7bb0dc6..bcfd0ea211 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -308,7 +308,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kUse4Over6, NVTENVFP44Over6ErrMode k4Over6ErrMode, bool kUse4Over6ErrFastMath> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -549,7 +549,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t participant_idx = data_row_idx % kFP4BlockScalingSize; nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec2_array_candidates_16x( + nvfp4_core::quantize_4over6_vec2_array_candidates_16x( smem_vec, scaling_factors, row_global_amax, candidates); const bool pick_map4 = nvfp4_core::record_and_select_4over6_2d_block( + nvfp4_core::quantize_4over6_vec2_array_16x( smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } @@ -713,7 +714,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t participant_idx = data_col_idx % kFP4BlockScalingSize; nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec_index_candidates_16x( + nvfp4_core::quantize_4over6_vec_index_candidates_16x( smem_vec, smem_idx, scaling_factors, global_amax[0], candidates); const bool pick_map4 = nvfp4_core::record_and_select_4over6_2d_block( + nvfp4_core::quantize_4over6_vec_index_16x( smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } @@ -806,9 +808,10 @@ void quantize_transpose_vector_blockwise_fp4( const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv, SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, - const bool swizzled_scale, const bool use_stochastic_rounding, const bool use_fast_math, + const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const SimpleTensor& noop_tensor, cudaStream_t stream) { + const bool use_4over6, const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, + const bool nvfp4_4over6_err_fast_math, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -827,6 +830,9 @@ void quantize_transpose_vector_blockwise_fp4( "Row-scaled NVFP4 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); + const NVTENVFP44Over6ErrMode use_4over6_err_mode = + use_4over6 ? nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; + const bool use_4over6_err_fast_math = use_4over6 && nvfp4_4over6_err_fast_math; const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -912,47 +918,57 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( use_4over6, kUse4Over6, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_fast_math, kUseFastMath, - - size_t smem_bytes = kSMemSize * sizeof(InputType); - auto kernel = block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, kPow2Scale, - kAligned, float, InputType, OutputType, ScaleType, - kSwizzledScale, kApplyStochasticRounding, - kIs2DBlockScaling, kRowScaledNVFP4, kUse4Over6, - kUseFastMath>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK( - err == cudaSuccess, - "Failed to set dynamic shared memory size."); - } kernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), - row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, - scale_t_stride_y, kScaleBlockDim, epsilon, - rng_state, - noop_ptr);) // kUseFastMath - ) // kUse4Over6 - ) // kRowScaledNVFP4 - ) // kIs2DBlockScaling - ) // kApplyStochasticRounding - ) // kSwizzledScale - ) // kAligned - ) // kReturnTranspose - ) // kReturnIdentity - ) // OutputType - ) // InputType + TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( + use_4over6_err_mode, k4Over6ErrMode, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6_err_fast_math, kUse4Over6ErrFastMath, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = + block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, + kPow2Scale, kAligned, float, InputType, + OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, + kIs2DBlockScaling, kRowScaledNVFP4, + kUse4Over6, k4Over6ErrMode, + kUse4Over6ErrFastMath>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared " + "memory size."); + } kernel<<>>( + reinterpret_cast( + input.dptr), + reinterpret_cast( + global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast( + scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, + scale_stride_y, scale_t_stride_x, + scale_t_stride_y, kScaleBlockDim, epsilon, + rng_state, + noop_ptr);) // kUse4Over6ErrFastMath + ) // k4Over6ErrMode + ) // kUse4Over6 + ) // kRowScaledNVFP4 + ) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); #else diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2f411b3bd6..e97c14e68e 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -329,6 +329,7 @@ class NVFP4Quantizer : public Quantizer { bool stochastic_rounding; // Whether emitted NVFP4 tensors use 4over6 candidate selection. bool use_4over6; + NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode; // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. bool row_scaled_nvfp4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 6d05955148..83ac117e2c 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1038,6 +1038,13 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, need_separate_rng_states, quant_config_list, quant_config_list_colwise); + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); + } + for (auto &config : quant_config_list_colwise) { + config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); + } + // Enable NVFP4 kernels to use math operations that sacrifice // accuracy for performance. These optimizations are experimental // and inconsistently implemented. @@ -1045,8 +1052,10 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { + if (use_fast_math && !quantizer.use_4over6) { for (auto &config : quant_config_list) { config.set_use_fast_math(true); } @@ -1055,6 +1064,17 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } + const auto use_4over6_err_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_FAST_MATH"); + if (use_4over6_err_fast_math) { + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_err_fast_math(true); + } + for (auto &config : quant_config_list_colwise) { + config.set_nvfp4_4over6_err_fast_math(true); + } + } + auto &quant_config_list_colwise_to_use = need_separate_rng_states ? quant_config_list_colwise : quant_config_list; @@ -1199,16 +1219,26 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, for (auto &config : quant_config_list) { config.set_nvfp4_4over6(quantizer.use_4over6); + config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } - // Fast math affects the 4over6 MSE computation when 4over6 is enabled + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { + if (use_fast_math && !quantizer.use_4over6) { for (auto &config : quant_config_list) { config.set_use_fast_math(true); } } + const auto use_4over6_err_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_FAST_MATH"); + if (use_4over6_err_fast_math) { + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_err_fast_math(true); + } + } + // We need: // 1. Rowwise amax = amax for input // 2. Columnwise amax = amax for input too diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b18a423864..5a0759a99c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1730,6 +1730,14 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); this->use_4over6 = quantizer.attr("use_4over6").cast(); + const auto nvfp4_4over6_err_mode = quantizer.attr("four_over_six_err_mode").cast(); + if (nvfp4_4over6_err_mode == "MAE") { + this->nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; + } else if (nvfp4_4over6_err_mode == "MSE") { + this->nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMSE; + } else { + NVTE_ERROR("Unsupported NVFP4 4over6 error mode: ", nvfp4_4over6_err_mode); + } this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); // Get amax reduction group if needed for NVFP4 AG @@ -2295,6 +2303,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); quant_config.set_nvfp4_4over6(this->use_4over6); + quant_config.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); + quant_config_columnwise.set_nvfp4_4over6(this->use_4over6); + quant_config_columnwise.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); if (this->use_4over6) { NVTE_CHECK(!this->with_rht, "NVFP4 4over6 quantization does not support RHT."); @@ -2441,12 +2452,21 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { + if (use_fast_math && !this->use_4over6) { quant_config.set_use_fast_math(true); quant_config_columnwise.set_use_fast_math(true); } + const auto use_4over6_err_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_FAST_MATH"); + if (use_4over6_err_fast_math) { + quant_config.set_nvfp4_4over6_err_fast_math(true); + quant_config_columnwise.set_nvfp4_4over6_err_fast_math(true); + } + if (this->with_rht) { if (eligible_for_rht_cast_fusion) { // fusion kernel requires passing in RHT matrix directly for maximum performance diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index d1c30f84f5..a6c1bffe60 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -352,9 +352,13 @@ def __init__( quant_tile_shape: Tuple[int, int] = (1, 16), row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_err_mode: str = "MAE", with_rht: bool = False, with_random_sign_mask: bool = True, ): + four_over_six_err_mode = four_over_six_err_mode.upper() + if four_over_six_err_mode not in ("MAE", "MSE"): + raise ValueError("four_over_six_err_mode must be 'MAE' or 'MSE'.") if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -378,6 +382,7 @@ def __init__( self.quant_tile_shape = quant_tile_shape self.row_scaled_nvfp4 = row_scaled_nvfp4 self.use_4over6 = use_4over6 + self.four_over_six_err_mode = four_over_six_err_mode self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -465,12 +470,13 @@ def _quantize_blockwise_4over6_reference( global_decode_scale: torch.Tensor, row_scaled_nvfp4: bool, tile_len_y: int, + four_over_six_err_mode: str, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize NVFP4 with 4over6 candidate selection. This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale, - MSE is computed in the original input domain with the 6 * 256 denominator, - and ties choose map-to-6. + the configured error is computed in the original input domain with the + 6 * 256 denominator, and ties choose map-to-6. """ m, num_blocks, tile_len_x = x.shape n = num_blocks * tile_len_x @@ -521,24 +527,30 @@ def _quantize_blockwise_4over6_reference( sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) if row_scaled_nvfp4: - mse_global_amax = global_amax.squeeze(-1) + error_global_amax = global_amax.squeeze(-1) else: - mse_global_amax = global_amax + error_global_amax = global_amax x_float = x.to(torch.float32) err_map4 = torch.zeros_like(vec_max) err_map6 = torch.zeros_like(vec_max) for idx in range(tile_len_x): val_map4 = fp4_map4[:, :, idx] * sf_map4 - val_map4 = val_map4 * mse_global_amax + val_map4 = val_map4 * error_global_amax val_map4 = val_map4 / denom diff_map4 = val_map4 - x_float[:, :, idx] - err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + if four_over_six_err_mode == "MSE": + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + else: + err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) val_map6 = fp4_map6[:, :, idx] * sf_map6 - val_map6 = val_map6 * mse_global_amax + val_map6 = val_map6 * error_global_amax val_map6 = val_map6 / denom diff_map6 = val_map6 - x_float[:, :, idx] - err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + if four_over_six_err_mode == "MSE": + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + else: + err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) if tile_len_y == 1: pick_map4 = err_map4 < err_map6 else: @@ -564,6 +576,7 @@ def _quantize_blockwise_reference( pow_2_scales: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_err_mode: str = "MAE", eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -632,7 +645,7 @@ def _quantize_blockwise_reference( global_decode_scale = torch.div(1.0, global_encode_scale) if use_4over6: # FourOverSix compares map-to-4 and map-to-6 candidates using - # the original input-domain MSE, while keeping TE-style FP4 + # the configured original input-domain error, while keeping TE-style FP4 # quantization for each candidate. return cls._quantize_blockwise_4over6_reference( x, @@ -642,6 +655,7 @@ def _quantize_blockwise_reference( global_decode_scale, row_scaled_nvfp4, tile_len_y, + four_over_six_err_mode, ) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) @@ -805,6 +819,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ pow_2_scales=self.pow_2_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, use_4over6=self.use_4over6, + four_over_six_err_mode=self.four_over_six_err_mode, eps=self.eps, ) if transpose_scales: @@ -829,6 +844,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, use_4over6=self.use_4over6, + four_over_six_err_mode=self.four_over_six_err_mode, eps=self.eps, ) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index da13991cf4..403c5cbc0d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1676,6 +1676,7 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and self.recipe.row_scaled_activation ), use_4over6=use_4over6, + four_over_six_err_mode=self.recipe.nvfp4_4over6_err_mode, ) if self.mode not in ("forward", "backward"): diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 625977ba48..c4def40e4f 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -132,6 +132,8 @@ class NVFP4Quantizer(Quantizer): row_scaled_nvfp4: bool """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection.""" use_4over6: bool + """NVFP4 4over6 candidate-selection error mode.""" + four_over_six_err_mode: str """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -150,6 +152,7 @@ def __init__( stochastic_rounding: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_err_mode: str = "MAE", with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -162,6 +165,9 @@ def __init__( self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 self.use_4over6 = use_4over6 + self.four_over_six_err_mode = four_over_six_err_mode.upper() + if self.four_over_six_err_mode not in ("MAE", "MSE"): + raise ValueError("four_over_six_err_mode must be 'MAE' or 'MSE'.") self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) From 0392708e62d3019807359c121929c7fc2daa6d0a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 12:59:34 -0700 Subject: [PATCH 46/57] Harden cpp test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 574 +++++++++++------- 1 file changed, 369 insertions(+), 205 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 7a10c50275..2c1a80cc7a 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -90,10 +91,27 @@ struct NVFP4FourOverSixQuantization { float reciprocal_map6; fp4e2m1x2 quantized_map4; fp4e2m1x2 quantized_map6; - float error_map4; - float error_map6; }; +enum class NVFP4FourOverSixCandidate { + Map4, + Map6, +}; + +enum class NVFP4ScalingMode { + Block1D, + RowScaled1D, + Block2D, +}; + +bool use_2d_quantization(const NVFP4ScalingMode scaling_mode) { + return scaling_mode == NVFP4ScalingMode::Block2D; +} + +bool row_scaled_nvfp4(const NVFP4ScalingMode scaling_mode) { + return scaling_mode == NVFP4ScalingMode::RowScaled1D; +} + NVFP4FourOverSixQuantization compute_4over6_quantization_scales( const float block_amax, const float global_encode_scale) { constexpr float fp4_max = 6.0f; @@ -126,49 +144,34 @@ NVFP4FourOverSixQuantization compute_4over6_quantization_scales( reciprocal_map6, fp4e2m1x2(zero), fp4e2m1x2(zero), - 0.0f, - 0.0f, }; } -float compute_4over6_dequantized_value(const double quantized_value, - const fp8e4m3 scale, - const float global_amax) { - constexpr float error_scale = 1.0f / (6.0f * 256.0f); - return static_cast(quantized_value) * static_cast(scale) * global_amax * - error_scale; +fp8e4m3 select_4over6_scale(const NVFP4FourOverSixQuantization& quantization, + const NVFP4FourOverSixCandidate candidate) { + if (candidate == NVFP4FourOverSixCandidate::Map4) { + return quantization.scale_map4; + } + return quantization.scale_map6; } -float compute_4over6_error(const float value, const float reference, - const NVTENVFP44Over6ErrMode err_mode) { - const float diff = value - reference; - if (err_mode == kNVTENVFP44Over6ErrMSE) { - return diff * diff; +fp4e2m1x2 select_4over6_quantized_pair(const NVFP4FourOverSixQuantization& quantization, + const NVFP4FourOverSixCandidate candidate) { + if (candidate == NVFP4FourOverSixCandidate::Map4) { + return quantization.quantized_map4; } - NVTE_CHECK(err_mode == kNVTENVFP44Over6ErrMAE, "Unsupported NVFP4 4over6 error mode."); - return fabsf(diff); + return quantization.quantized_map6; } NVFP4FourOverSixQuantization quantize_4over6_pair( - const float x, const float y, const NVFP4FourOverSixQuantization& quantization, - const float global_amax, const NVTENVFP44Over6ErrMode err_mode) { + const float x, const float y, const NVFP4FourOverSixQuantization& quantization) { const float2 scaled_map4 = {x * quantization.reciprocal_map4, y * quantization.reciprocal_map4}; const fp4e2m1x2 quantized_map4(scaled_map4); - const double2 truncated_map4 = cvt_fp4x2_to_double2(quantized_map4); - const float dequant_x_map4 = - compute_4over6_dequantized_value(truncated_map4.x, quantization.scale_map4, global_amax); - const float dequant_y_map4 = - compute_4over6_dequantized_value(truncated_map4.y, quantization.scale_map4, global_amax); const float2 scaled_map6 = {x * quantization.reciprocal_map6, y * quantization.reciprocal_map6}; const fp4e2m1x2 quantized_map6(scaled_map6); - const double2 truncated_map6 = cvt_fp4x2_to_double2(quantized_map6); - const float dequant_x_map6 = - compute_4over6_dequantized_value(truncated_map6.x, quantization.scale_map6, global_amax); - const float dequant_y_map6 = - compute_4over6_dequantized_value(truncated_map6.y, quantization.scale_map6, global_amax); return { quantization.scale_map4, @@ -177,10 +180,6 @@ NVFP4FourOverSixQuantization quantize_4over6_pair( quantization.reciprocal_map6, quantized_map4, quantized_map6, - compute_4over6_error(dequant_x_map4, x, err_mode) + - compute_4over6_error(dequant_y_map4, y, err_mode), - compute_4over6_error(dequant_x_map6, x, err_mode) + - compute_4over6_error(dequant_y_map6, y, err_mode), }; } @@ -196,7 +195,8 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float global_amax, const bool use_fast_math, const bool use_4over6 = false, - const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { // Compute a global encoding/decoding scaling factor for all S_dec_b const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); @@ -234,39 +234,18 @@ void quantize_nvfp4_1d(float (*OP)(const float), if (use_4over6) { const NVFP4FourOverSixQuantization quantization = compute_4over6_quantization_scales(block_amax, S_enc); - - std::array output_map6; - std::array output_map4; - float err_map6 = 0.0f; - float err_map4 = 0.0f; + scales[scale_idx] = select_4over6_scale(quantization, four_over_six_candidate); for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; const int cache_idx_x = j - j_min; const int cache_idx_y = cache_idx_x + 1; const float cached_x = cache_buffer[cache_idx_x]; const float cached_y = cache_buffer[cache_idx_y]; const NVFP4FourOverSixQuantization pair_quantization = - quantize_4over6_pair(cached_x, cached_y, quantization, global_amax, - err_mode); - output_map4[cache_idx_x / 2] = pair_quantization.quantized_map4; - output_map6[cache_idx_x / 2] = pair_quantization.quantized_map6; - err_map4 += pair_quantization.error_map4; - err_map6 += pair_quantization.error_map6; - } - - const bool pick_map4 = err_map4 < err_map6; - if (pick_map4) { - scales[scale_idx] = quantization.scale_map4; - for (size_t j = j_min; j < j_max; j += 2) { - const int idx_pair = (i * cols + j) / 2; - output[idx_pair] = output_map4[(j - j_min) / 2]; - } - } else { - scales[scale_idx] = quantization.scale_map6; - for (size_t j = j_min; j < j_max; j += 2) { - const int idx_pair = (i * cols + j) / 2; - output[idx_pair] = output_map6[(j - j_min) / 2]; - } + quantize_4over6_pair(cached_x, cached_y, quantization); + output[idx_pair] = + select_4over6_quantized_pair(pair_quantization, four_over_six_candidate); } continue; } @@ -317,8 +296,8 @@ void compute_2d_mathematical_scales(float (*OP)(const float), std::vector>& math_scales, const bool use_fast_math, const bool use_4over6 = false, - const NVTENVFP44Over6ErrMode err_mode = - kNVTENVFP44Over6ErrMAE) { + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); constexpr size_t block_size_Y = 16; @@ -351,33 +330,8 @@ void compute_2d_mathematical_scales(float (*OP)(const float), if (use_4over6) { const NVFP4FourOverSixQuantization quantization = compute_4over6_quantization_scales(block_amax, S_enc); - - float err_map6 = 0.0f; - float err_map4 = 0.0f; - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; j += 2) { - const float input_x = static_cast(input[i * cols + j]); - const float act_x = OP(input_x); - const float cached_x = static_cast(static_cast(act_x)); - float cached_y = 0.0f; - if (j + 1 < j_max) { - const float input_y = static_cast(input[i * cols + j + 1]); - const float act_y = OP(input_y); - cached_y = static_cast(static_cast(act_y)); - } - const NVFP4FourOverSixQuantization pair_quantization = - quantize_4over6_pair(cached_x, cached_y, quantization, global_amax, - err_mode); - err_map4 += pair_quantization.error_map4; - err_map6 += pair_quantization.error_map6; - } - } - - if (err_map4 < err_map6) { - math_scales[block_Y][block_X] = quantization.scale_map4; - } else { - math_scales[block_Y][block_X] = quantization.scale_map6; - } + math_scales[block_Y][block_X] = + select_4over6_scale(quantization, four_over_six_candidate); } else { const float S_dec_b = block_amax / 6.0f * S_enc; const fp8e4m3 S_dec_b_fp8_map6 = static_cast(S_dec_b); @@ -399,12 +353,13 @@ void quantize_nvfp4_2d(float (*OP)(const float), const float global_amax, const bool use_fast_math, const bool use_4over6 = false, - const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math, - use_4over6, err_mode); + use_4over6, four_over_six_candidate); const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); constexpr size_t block_size_Y = 16; @@ -436,7 +391,7 @@ void quantize_nvfp4_2d(float (*OP)(const float), // Get the scaling factor for this block const float S_dec_b_fp8 = static_cast(math_scales[block_Y][block_X]); - const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float S_enc_b_fp8 = S_dec_b_fp8 == 0.0f ? 0.0f : S_enc / S_dec_b_fp8; const float scale_reciprocal = S_enc_b_fp8; // Process and cache data for this 16x16 block @@ -490,13 +445,14 @@ void quantize_nvfp4(float (*OP)(const float), const bool use_fast_math, const bool use_2d_quantization = false, const bool use_4over6 = false, - const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { if (use_2d_quantization) { quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6, err_mode); + use_fast_math, use_4over6, four_over_six_candidate); } else { quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6, err_mode); + use_fast_math, use_4over6, four_over_six_candidate); } } @@ -516,7 +472,8 @@ void compute_ref(float (*OP)(const float), const bool use_2d_quantization = false, const bool row_scaled_nvfp4 = false, const bool use_4over6 = false, - const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { std::vector input_t = create_transpose(input, rows, cols); NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4), @@ -527,7 +484,7 @@ void compute_ref(float (*OP)(const float), // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math, - use_4over6, err_mode); + use_4over6, four_over_six_candidate); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -555,9 +512,11 @@ void compute_ref(float (*OP)(const float), // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, *amax, - use_fast_math, use_4over6, err_mode); // scales already filled + use_fast_math, use_4over6, + four_over_six_candidate); // scales already filled quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, *amax, - use_fast_math, use_4over6, err_mode); // scales_t already filled + use_fast_math, use_4over6, + four_over_six_candidate); // scales_t already filled return; } @@ -576,16 +535,16 @@ void compute_ref(float (*OP)(const float), use_fast_math, use_2d_quantization, use_4over6, - err_mode); + four_over_six_candidate); } return; } // Ref impl for basic NVFP4 quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, *amax, - use_fast_math, use_2d_quantization, use_4over6, err_mode); + use_fast_math, use_2d_quantization, use_4over6, four_over_six_candidate); quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, *amax, - use_fast_math, use_2d_quantization, use_4over6, err_mode); + use_fast_math, use_2d_quantization, use_4over6, four_over_six_candidate); } void compare_nvfp4_tensors(const std::string& name, @@ -710,6 +669,131 @@ void compareResults_nvfp4(Tensor &test, } } +template +bool bitwise_equal(const T& x, const T& y) { + return std::memcmp(&x, &y, sizeof(T)) == 0; +} + +bool nvfp4_output_block_matches(const fp4e2m1x2* const test_data, + const fp4e2m1x2* const ref_data, + const size_t row, + const size_t cols, + const size_t block_x) { + constexpr size_t block_size_X = 16; + const size_t j_min = block_x * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + for (size_t j = j_min; j < j_max; j += 2) { + const size_t idx_pair = (row * cols + j) / 2; + if (!bitwise_equal(test_data[idx_pair], ref_data[idx_pair])) { + return false; + } + } + return true; +} + +void compare_nvfp4_4over6_candidates(const std::string& name, + const fp4e2m1* const test_data, + const fp8e4m3* const test_scales, + const fp4e2m1x2* const ref_data_map4, + const fp8e4m3* const ref_scales_map4, + const fp4e2m1x2* const ref_data_map6, + const fp8e4m3* const ref_scales_map6, + const size_t rows, + const size_t cols, + const size_t blocks_X, + const size_t scales_stride) { + constexpr int max_mismatches_to_print = 3; + const auto* const test_data_pairs = reinterpret_cast(test_data); + size_t total_mismatches = 0; + + for (size_t row = 0; row < rows; ++row) { + for (size_t block_x = 0; block_x < blocks_X; ++block_x) { + const size_t scale_idx = row * scales_stride + block_x; + const bool scale_matches_map4 = + bitwise_equal(test_scales[scale_idx], ref_scales_map4[scale_idx]); + const bool data_matches_map4 = + nvfp4_output_block_matches(test_data_pairs, ref_data_map4, row, cols, block_x); + const bool scale_matches_map6 = + bitwise_equal(test_scales[scale_idx], ref_scales_map6[scale_idx]); + const bool data_matches_map6 = + nvfp4_output_block_matches(test_data_pairs, ref_data_map6, row, cols, block_x); + + if ((scale_matches_map4 && data_matches_map4) || + (scale_matches_map6 && data_matches_map6)) { + continue; + } + + ++total_mismatches; + if (total_mismatches <= max_mismatches_to_print) { + std::cout << "Error in tensor " << name << ": 4over6 block mismatch at row " + << row << ", block_x " << block_x + << ". The output did not match either map-to-4 or map-to-6 exactly." + << std::endl; + } + } + } + + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total 4over6 blocks checked: " << (rows * blocks_X) << std::endl; + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatched 4over6 blocks found: " << total_mismatches << std::endl; + std::cout << "============================" << std::endl; + GTEST_FAIL() << "Found " << total_mismatches << " 4over6 block mismatches in tensor " + << name; + } + + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "Each 4over6 block matched either map-to-4 or map-to-6 exactly" << std::endl; + std::cout << "============================" << std::endl; +} + +void compareResults_nvfp4_4over6(Tensor& test, + const fp4e2m1x2* const ref, + const fp4e2m1x2* const ref_t, + const fp8e4m3* const ref_scales, + const fp8e4m3* const ref_scales_t, + const fp4e2m1x2* const ref_map6, + const fp4e2m1x2* const ref_t_map6, + const fp8e4m3* const ref_scales_map6, + const fp8e4m3* const ref_scales_t_map6, + const size_t rows, + const size_t cols, + const size_t blocks_X, + const size_t blocks_X_t, + const size_t scales_stride, + const size_t scales_stride_t, + const bool if_on_gpus = true, + const bool compare_columnwise = true) { + if (if_on_gpus) test.to_cpu(); + + compare_nvfp4_4over6_candidates("output", + test.rowwise_cpu_dptr(), + test.rowwise_cpu_scale_inv_ptr(), + ref, + ref_scales, + ref_map6, + ref_scales_map6, + rows, + cols, + blocks_X, + scales_stride); + + if (compare_columnwise) { + compare_nvfp4_4over6_candidates("output_t", + test.columnwise_cpu_dptr(), + test.columnwise_cpu_scale_inv_ptr(), + ref_t, + ref_scales_t, + ref_t_map6, + ref_scales_t_map6, + cols, + rows, + blocks_X_t, + scales_stride_t); + } +} + void compare_rowwise_amax(Tensor &output, const std::vector &ref_amax) { ASSERT_EQ(output.rowwise_amax_size(), ref_amax.size()); const auto *amax_ptr = output.cpu_rowwise_amax_ptr(); @@ -724,20 +808,18 @@ template void performTest(float (*OP)(const float), const std::vector& shape, const bool use_fast_math, - const bool use_2d_quantization = false, - const bool row_scaled_nvfp4 = false, + const NVFP4ScalingMode scaling_mode = NVFP4ScalingMode::Block1D, const bool use_4over6 = false, const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { using namespace test; - NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4), - "2D quantization and row-scaling are not supported together."); - DType itype = TypeInfo::dtype; DType otype = DType::kFloat4E2M1; + const bool is_2d_quantization = use_2d_quantization(scaling_mode); + const bool is_row_scaled_nvfp4 = row_scaled_nvfp4(scaling_mode); const bool rowwise = true; - const bool columnwise = !row_scaled_nvfp4; + const bool columnwise = !is_row_scaled_nvfp4; const size_t rows = first_dimension(shape); const size_t cols = last_dimension(shape); @@ -767,12 +849,49 @@ void performTest(float (*OP)(const float), std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + std::unique_ptr ref_output_map6; + std::unique_ptr ref_output_t_map6; + std::unique_ptr ref_scales_map6; + std::unique_ptr ref_scales_t_map6; fillCase(&input, InputsFillCase::uniform); + if (use_4over6 && is_row_scaled_nvfp4) { + constexpr float target_row_amax = 256.0f * 6.0f * 8.0f; + auto *input_vals = input.rowwise_cpu_dptr(); + for (size_t row = 0; row < rows; ++row) { + float row_amax = 0.0f; + size_t max_col = 0; + for (size_t col = 0; col < cols; ++col) { + const float val = static_cast(input_vals[row * cols + col]); + const float abs_val = fabsf(val); + if (abs_val > row_amax) { + row_amax = abs_val; + max_col = col; + } + } + + if (row_amax == 0.0f) { + continue; + } + + const float row_scale = target_row_amax / row_amax; + for (size_t col = 0; col < cols; ++col) { + float scaled = static_cast(input_vals[row * cols + col]) * row_scale; + scaled = fminf(fmaxf(scaled, -target_row_amax), target_row_amax); + input_vals[row * cols + col] = static_cast(scaled); + } + + const float max_val = static_cast(input_vals[row * cols + max_col]); + input_vals[row * cols + max_col] = + static_cast(max_val < 0.0f ? -target_row_amax : target_row_amax); + } + input.from_cpu(); + } + // Compute 2nd stage NVFP4 scaling factor std::vector ref_amax; - if (row_scaled_nvfp4) { + if (is_row_scaled_nvfp4) { // Compute per-row amaxes const auto *input_vals = input.rowwise_cpu_dptr(); for (size_t row = 0; row < rows; ++row){ @@ -786,7 +905,7 @@ void performTest(float (*OP)(const float), // Update tensor // Note: No need to update amax like standard NVFP4, amaxes // are computed during quantization. - output.set_row_scaled_nvfp4(row_scaled_nvfp4); + output.set_row_scaled_nvfp4(is_row_scaled_nvfp4); } else { // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues if (use_4over6) { @@ -805,22 +924,61 @@ void performTest(float (*OP)(const float), output.from_cpu(); } - compute_ref(OP, - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_t.get(), - ref_scales.get(), - ref_scales_t.get(), - ref_amax.data(), - rows, - cols, - scales_stride, - scales_stride_t, - use_fast_math, - use_2d_quantization, - row_scaled_nvfp4, - use_4over6, - err_mode); + if (use_4over6) { + ref_output_map6 = std::make_unique(rows * (cols / 2)); + ref_output_t_map6 = std::make_unique(cols * (rows / 2)); + ref_scales_map6 = std::make_unique(blocks_Y * blocks_X); + ref_scales_t_map6 = std::make_unique(blocks_Y_t * blocks_X_t); + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + ref_amax.data(), + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + is_2d_quantization, + is_row_scaled_nvfp4, + use_4over6, + NVFP4FourOverSixCandidate::Map4); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output_map6.get(), + ref_output_t_map6.get(), + ref_scales_map6.get(), + ref_scales_t_map6.get(), + ref_amax.data(), + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + is_2d_quantization, + is_row_scaled_nvfp4, + use_4over6, + NVFP4FourOverSixCandidate::Map6); + } else { + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + ref_amax.data(), + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + is_2d_quantization, + is_row_scaled_nvfp4, + use_4over6); + } // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); @@ -833,7 +991,7 @@ void performTest(float (*OP)(const float), quant_config.set_use_fast_math(use_fast_math && !use_4over6); quant_config.set_stochastic_rounding(false); quant_config.set_rng_state(rng_state.data()); - quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + quant_config.set_nvfp4_2d_quantization(is_2d_quantization); quant_config.set_nvfp4_4over6(use_4over6); quant_config.set_nvfp4_4over6_err_mode(err_mode); @@ -864,21 +1022,42 @@ void performTest(float (*OP)(const float), const double atol = 1.0E-6; const double rtol = 1.0E-6; - // Set dump_data=true to enable dumping tensor data to files for analysis - compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, - false, !row_scaled_nvfp4); - - size_t scale_mismatches_num = 0; - compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), - ref_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - scale_mismatches_num); - - if (!row_scaled_nvfp4) { - compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_t.get(), - unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + if (use_4over6) { + compareResults_nvfp4_4over6(output, + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + ref_output_map6.get(), + ref_output_t_map6.get(), + ref_scales_map6.get(), + ref_scales_t_map6.get(), + rows, + cols, + unpadded_blocks_X, + unpadded_blocks_X_t, + scales_stride, + scales_stride_t, + true, + !is_row_scaled_nvfp4); + } else { + // Set dump_data=true to enable dumping tensor data to files for analysis + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, + true, false, !is_row_scaled_nvfp4); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); + + if (!is_row_scaled_nvfp4) { + compare_scaling_factors("scales_t", + output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, + scales_stride_t, scale_mismatches_num); + } } compare_rowwise_amax(output, ref_amax); @@ -915,8 +1094,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam std::vector, transformer_engine::DType, bool, - bool, - bool, + NVFP4ScalingMode, bool, NVTENVFP44Over6ErrMode>> {}; @@ -933,10 +1111,9 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); - const bool use_2d_quantization = std::get<4>(GetParam()); - const bool row_scaled_nvfp4 = std::get<5>(GetParam()); - const bool use_4over6 = std::get<6>(GetParam()); - const NVTENVFP44Over6ErrMode err_mode = std::get<7>(GetParam()); + const NVFP4ScalingMode scaling_mode = std::get<4>(GetParam()); + const bool use_4over6 = std::get<5>(GetParam()); + const NVTENVFP44Over6ErrMode err_mode = std::get<6>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -954,8 +1131,8 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, use_2d_quantization, - row_scaled_nvfp4, use_4over6, err_mode); + performTest(OP, tensor_dims, use_fast_math, scaling_mode, use_4over6, + err_mode); ); } @@ -971,6 +1148,15 @@ std::string to_string(const ActivationType Act_type) { } } +std::string to_string(const NVFP4ScalingMode scaling_mode) { + switch (scaling_mode) { + case NVFP4ScalingMode::Block1D: return ""; + case NVFP4ScalingMode::RowScaled1D: return "XROW_SCALED"; + case NVFP4ScalingMode::Block2D: return "X2D"; + default: return ""; + } +} + std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) { std::string name = to_string(std::get<0>(param)); const auto& shape = std::get<1>(param); @@ -981,15 +1167,10 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) if (std::get<3>(param)) { name += "X_FAST_SCALING"; } - if (std::get<4>(param)) { - name += "X2D"; - } + name += to_string(std::get<4>(param)); if (std::get<5>(param)) { - name += "XROW_SCALED"; - } - if (std::get<6>(param)) { name += "X4OVER6"; - if (std::get<7>(param) == kNVTENVFP44Over6ErrMSE) { + if (std::get<6>(param) == kNVTENVFP44Over6ErrMSE) { name += "XMSE"; } else { name += "XMAE"; @@ -1002,14 +1183,13 @@ INSTANTIATE_TEST_SUITE_P( OperatorTest, FusedCastTransposeNVFP4TestSuite, ::testing::Combine( - ::testing::ValuesIn(Activation_types), - ::testing::ValuesIn(tensor_dims), - ::testing::Values(DType::kBFloat16), - ::testing::Values(false), - ::testing::Values(false), - ::testing::Values(false), - ::testing::Values(false), - ::testing::Values(kNVTENVFP44Over6ErrMAE)), + ::testing::ValuesIn(Activation_types), // activation_type + ::testing::ValuesIn(tensor_dims), // tensor_dims + ::testing::Values(DType::kBFloat16), // input_type + ::testing::Values(false), // use_fast_math + ::testing::Values(NVFP4ScalingMode::Block1D), // scaling_mode + ::testing::Values(false), // use_4over6 + ::testing::Values(kNVTENVFP44Over6ErrMAE)), // err_mode [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1018,14 +1198,13 @@ INSTANTIATE_TEST_SUITE_P( OperatorTestRowScaled, FusedCastTransposeNVFP4TestSuite, ::testing::Combine( - ::testing::Values(ActivationType::Identity), - ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), - ::testing::Values(DType::kBFloat16, DType::kFloat32), - ::testing::Values(false), - ::testing::Values(false), - ::testing::Values(true), - ::testing::Values(false), - ::testing::Values(kNVTENVFP44Over6ErrMAE)), + ::testing::ValuesIn(Activation_types), // activation_type + ::testing::ValuesIn(tensor_dims), // tensor_dims + ::testing::Values(DType::kBFloat16, DType::kFloat32), // input_type + ::testing::Values(false), // use_fast_math + ::testing::Values(NVFP4ScalingMode::RowScaled1D), // scaling_mode + ::testing::Values(false), // use_4over6 + ::testing::Values(kNVTENVFP44Over6ErrMAE)), // err_mode [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1034,46 +1213,31 @@ INSTANTIATE_TEST_SUITE_P( OperatorTest4Over6, FusedCastTransposeNVFP4TestSuite, ::testing::Combine( - ::testing::Values(ActivationType::Identity), - ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), - ::testing::Values(DType::kBFloat16, DType::kFloat32), - ::testing::Values(false), - ::testing::Values(false), - ::testing::Values(false), - ::testing::Values(true), - ::testing::Values(kNVTENVFP44Over6ErrMAE, kNVTENVFP44Over6ErrMSE)), - [](const testing::TestParamInfo& info) { - return test_name(info.param); - }); - -INSTANTIATE_TEST_SUITE_P( - OperatorTest4Over6RowScaled, - FusedCastTransposeNVFP4TestSuite, - ::testing::Combine( - ::testing::Values(ActivationType::Identity), - ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), - ::testing::Values(DType::kFloat32), - ::testing::Values(false), - ::testing::Values(false), - ::testing::Values(true), - ::testing::Values(true), - ::testing::Values(kNVTENVFP44Over6ErrMAE, kNVTENVFP44Over6ErrMSE)), + ::testing::ValuesIn(Activation_types), // activation_type + ::testing::ValuesIn(tensor_dims), // tensor_dims + ::testing::Values(DType::kBFloat16, DType::kFloat32), // input_type + ::testing::Values(false), // use_fast_math + ::testing::Values(NVFP4ScalingMode::Block1D, + NVFP4ScalingMode::Block2D), // scaling_mode + ::testing::Values(true), // use_4over6 + ::testing::Values(kNVTENVFP44Over6ErrMAE, + kNVTENVFP44Over6ErrMSE)), // err_mode [](const testing::TestParamInfo& info) { return test_name(info.param); }); INSTANTIATE_TEST_SUITE_P( - OperatorTest4Over62D, + OperatorTestRowScaled4Over6, FusedCastTransposeNVFP4TestSuite, ::testing::Combine( - ::testing::Values(ActivationType::Identity), - ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), - ::testing::Values(DType::kBFloat16), - ::testing::Values(false), - ::testing::Values(true), - ::testing::Values(false), - ::testing::Values(true), - ::testing::Values(kNVTENVFP44Over6ErrMAE, kNVTENVFP44Over6ErrMSE)), + ::testing::ValuesIn(Activation_types), // activation_type + ::testing::ValuesIn(tensor_dims), // tensor_dims + ::testing::Values(DType::kFloat32), // input_type + ::testing::Values(false), // use_fast_math + ::testing::Values(NVFP4ScalingMode::RowScaled1D), // scaling_mode + ::testing::Values(true), // use_4over6 + ::testing::Values(kNVTENVFP44Over6ErrMAE, + kNVTENVFP44Over6ErrMSE)), // err_mode [](const testing::TestParamInfo& info) { return test_name(info.param); }); From 0b77a373ade589f82b73a209778fd5cbb370277e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 13:28:21 -0700 Subject: [PATCH 47/57] Add warning and err fast math coverage Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 130 ++++++++---------- 1 file changed, 58 insertions(+), 72 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 2c1a80cc7a..8e8d4f658a 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -6,7 +6,6 @@ #include #include -#include #include #include #include @@ -671,7 +670,14 @@ void compareResults_nvfp4(Tensor &test, template bool bitwise_equal(const T& x, const T& y) { - return std::memcmp(&x, &y, sizeof(T)) == 0; + const auto *x_bytes = reinterpret_cast(&x); + const auto *y_bytes = reinterpret_cast(&y); + for (size_t i = 0; i < sizeof(T); ++i) { + if (x_bytes[i] != y_bytes[i]) { + return false; + } + } + return true; } bool nvfp4_output_block_matches(const fp4e2m1x2* const test_data, @@ -748,52 +754,6 @@ void compare_nvfp4_4over6_candidates(const std::string& name, std::cout << "============================" << std::endl; } -void compareResults_nvfp4_4over6(Tensor& test, - const fp4e2m1x2* const ref, - const fp4e2m1x2* const ref_t, - const fp8e4m3* const ref_scales, - const fp8e4m3* const ref_scales_t, - const fp4e2m1x2* const ref_map6, - const fp4e2m1x2* const ref_t_map6, - const fp8e4m3* const ref_scales_map6, - const fp8e4m3* const ref_scales_t_map6, - const size_t rows, - const size_t cols, - const size_t blocks_X, - const size_t blocks_X_t, - const size_t scales_stride, - const size_t scales_stride_t, - const bool if_on_gpus = true, - const bool compare_columnwise = true) { - if (if_on_gpus) test.to_cpu(); - - compare_nvfp4_4over6_candidates("output", - test.rowwise_cpu_dptr(), - test.rowwise_cpu_scale_inv_ptr(), - ref, - ref_scales, - ref_map6, - ref_scales_map6, - rows, - cols, - blocks_X, - scales_stride); - - if (compare_columnwise) { - compare_nvfp4_4over6_candidates("output_t", - test.columnwise_cpu_dptr(), - test.columnwise_cpu_scale_inv_ptr(), - ref_t, - ref_scales_t, - ref_t_map6, - ref_scales_t_map6, - cols, - rows, - blocks_X_t, - scales_stride_t); - } -} - void compare_rowwise_amax(Tensor &output, const std::vector &ref_amax) { ASSERT_EQ(output.rowwise_amax_size(), ref_amax.size()); const auto *amax_ptr = output.cpu_rowwise_amax_ptr(); @@ -810,9 +770,17 @@ void performTest(float (*OP)(const float), const bool use_fast_math, const NVFP4ScalingMode scaling_mode = NVFP4ScalingMode::Block1D, const bool use_4over6 = false, - const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE) { + const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE, + const bool use_4over6_err_fast_math = false) { using namespace test; + if (use_4over6 && use_fast_math) { + std::cout << "WARNING: Plain NVFP4 fast math is ignored for 4over6. " + "Use use_4over6_err_fast_math to test the 4over6 candidate " + "error fast-math path." + << std::endl; + } + DType itype = TypeInfo::dtype; DType otype = DType::kFloat4E2M1; @@ -994,6 +962,7 @@ void performTest(float (*OP)(const float), quant_config.set_nvfp4_2d_quantization(is_2d_quantization); quant_config.set_nvfp4_4over6(use_4over6); quant_config.set_nvfp4_4over6_err_mode(err_mode); + quant_config.set_nvfp4_4over6_err_fast_math(use_4over6 && use_4over6_err_fast_math); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -1023,23 +992,31 @@ void performTest(float (*OP)(const float), const double rtol = 1.0E-6; if (use_4over6) { - compareResults_nvfp4_4over6(output, - ref_output.get(), - ref_output_t.get(), - ref_scales.get(), - ref_scales_t.get(), - ref_output_map6.get(), - ref_output_t_map6.get(), - ref_scales_map6.get(), - ref_scales_t_map6.get(), - rows, - cols, - unpadded_blocks_X, - unpadded_blocks_X_t, - scales_stride, - scales_stride_t, - true, - !is_row_scaled_nvfp4); + output.to_cpu(); + compare_nvfp4_4over6_candidates("output", + output.rowwise_cpu_dptr(), + output.rowwise_cpu_scale_inv_ptr(), + ref_output.get(), + ref_scales.get(), + ref_output_map6.get(), + ref_scales_map6.get(), + rows, + cols, + unpadded_blocks_X, + scales_stride); + if (!is_row_scaled_nvfp4) { + compare_nvfp4_4over6_candidates("output_t", + output.columnwise_cpu_dptr(), + output.columnwise_cpu_scale_inv_ptr(), + ref_output_t.get(), + ref_scales_t.get(), + ref_output_t_map6.get(), + ref_scales_t_map6.get(), + cols, + rows, + unpadded_blocks_X_t, + scales_stride_t); + } } else { // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, @@ -1096,7 +1073,8 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam bool, NVFP4ScalingMode, bool, - NVTENVFP44Over6ErrMode>> {}; + NVTENVFP44Over6ErrMode, + bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { // Skip tests for pre-Blackwell architectures @@ -1114,6 +1092,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const NVFP4ScalingMode scaling_mode = std::get<4>(GetParam()); const bool use_4over6 = std::get<5>(GetParam()); const NVTENVFP44Over6ErrMode err_mode = std::get<6>(GetParam()); + const bool use_4over6_err_fast_math = std::get<7>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -1132,7 +1111,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, performTest(OP, tensor_dims, use_fast_math, scaling_mode, use_4over6, - err_mode); + err_mode, use_4over6_err_fast_math); ); } @@ -1175,6 +1154,9 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) } else { name += "XMAE"; } + if (std::get<7>(param)) { + name += "XERR_FAST_MATH"; + } } return name; } @@ -1189,7 +1171,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), // use_fast_math ::testing::Values(NVFP4ScalingMode::Block1D), // scaling_mode ::testing::Values(false), // use_4over6 - ::testing::Values(kNVTENVFP44Over6ErrMAE)), // err_mode + ::testing::Values(kNVTENVFP44Over6ErrMAE), // err_mode + ::testing::Values(false)), // use_4over6_err_fast_math [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1204,7 +1187,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false), // use_fast_math ::testing::Values(NVFP4ScalingMode::RowScaled1D), // scaling_mode ::testing::Values(false), // use_4over6 - ::testing::Values(kNVTENVFP44Over6ErrMAE)), // err_mode + ::testing::Values(kNVTENVFP44Over6ErrMAE), // err_mode + ::testing::Values(false)), // use_4over6_err_fast_math [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1221,7 +1205,8 @@ INSTANTIATE_TEST_SUITE_P( NVFP4ScalingMode::Block2D), // scaling_mode ::testing::Values(true), // use_4over6 ::testing::Values(kNVTENVFP44Over6ErrMAE, - kNVTENVFP44Over6ErrMSE)), // err_mode + kNVTENVFP44Over6ErrMSE), // err_mode + ::testing::Values(false, true)), // use_4over6_err_fast_math [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1237,7 +1222,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(NVFP4ScalingMode::RowScaled1D), // scaling_mode ::testing::Values(true), // use_4over6 ::testing::Values(kNVTENVFP44Over6ErrMAE, - kNVTENVFP44Over6ErrMSE)), // err_mode + kNVTENVFP44Over6ErrMSE), // err_mode + ::testing::Values(false, true)), // use_4over6_err_fast_math [](const testing::TestParamInfo& info) { return test_name(info.param); }); From 81e579e8fbae953ce220e7434e1a2b769e9f7b06 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 13:40:55 -0700 Subject: [PATCH 48/57] Fold test case and clean up cpp test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 60 +++++++------------ 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 8e8d4f658a..1efc737cf9 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -103,6 +103,12 @@ enum class NVFP4ScalingMode { Block2D, }; +struct NVFP4FourOverSixTestConfig { + bool enabled = false; + NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE; + bool err_fast_math = false; +}; + bool use_2d_quantization(const NVFP4ScalingMode scaling_mode) { return scaling_mode == NVFP4ScalingMode::Block2D; } @@ -1072,9 +1078,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam transformer_engine::DType, bool, NVFP4ScalingMode, - bool, - NVTENVFP44Over6ErrMode, - bool>> {}; + NVFP4FourOverSixTestConfig>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { // Skip tests for pre-Blackwell architectures @@ -1090,9 +1094,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); const NVFP4ScalingMode scaling_mode = std::get<4>(GetParam()); - const bool use_4over6 = std::get<5>(GetParam()); - const NVTENVFP44Over6ErrMode err_mode = std::get<6>(GetParam()); - const bool use_4over6_err_fast_math = std::get<7>(GetParam()); + const NVFP4FourOverSixTestConfig config = std::get<5>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -1110,8 +1112,8 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, scaling_mode, use_4over6, - err_mode, use_4over6_err_fast_math); + performTest(OP, tensor_dims, use_fast_math, scaling_mode, config.enabled, + config.err_mode, config.err_fast_math); ); } @@ -1147,14 +1149,15 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) name += "X_FAST_SCALING"; } name += to_string(std::get<4>(param)); - if (std::get<5>(param)) { + const NVFP4FourOverSixTestConfig& config = std::get<5>(param); + if (config.enabled) { name += "X4OVER6"; - if (std::get<6>(param) == kNVTENVFP44Over6ErrMSE) { + if (config.err_mode == kNVTENVFP44Over6ErrMSE) { name += "XMSE"; } else { name += "XMAE"; } - if (std::get<7>(param)) { + if (config.err_fast_math) { name += "XERR_FAST_MATH"; } } @@ -1170,9 +1173,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kBFloat16), // input_type ::testing::Values(false), // use_fast_math ::testing::Values(NVFP4ScalingMode::Block1D), // scaling_mode - ::testing::Values(false), // use_4over6 - ::testing::Values(kNVTENVFP44Over6ErrMAE), // err_mode - ::testing::Values(false)), // use_4over6_err_fast_math + ::testing::Values(NVFP4FourOverSixTestConfig{})), // four_over_six_config [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1186,9 +1187,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kBFloat16, DType::kFloat32), // input_type ::testing::Values(false), // use_fast_math ::testing::Values(NVFP4ScalingMode::RowScaled1D), // scaling_mode - ::testing::Values(false), // use_4over6 - ::testing::Values(kNVTENVFP44Over6ErrMAE), // err_mode - ::testing::Values(false)), // use_4over6_err_fast_math + ::testing::Values(NVFP4FourOverSixTestConfig{})), // four_over_six_config [](const testing::TestParamInfo& info) { return test_name(info.param); }); @@ -1202,28 +1201,13 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kBFloat16, DType::kFloat32), // input_type ::testing::Values(false), // use_fast_math ::testing::Values(NVFP4ScalingMode::Block1D, + NVFP4ScalingMode::RowScaled1D, NVFP4ScalingMode::Block2D), // scaling_mode - ::testing::Values(true), // use_4over6 - ::testing::Values(kNVTENVFP44Over6ErrMAE, - kNVTENVFP44Over6ErrMSE), // err_mode - ::testing::Values(false, true)), // use_4over6_err_fast_math - [](const testing::TestParamInfo& info) { - return test_name(info.param); - }); - -INSTANTIATE_TEST_SUITE_P( - OperatorTestRowScaled4Over6, - FusedCastTransposeNVFP4TestSuite, - ::testing::Combine( - ::testing::ValuesIn(Activation_types), // activation_type - ::testing::ValuesIn(tensor_dims), // tensor_dims - ::testing::Values(DType::kFloat32), // input_type - ::testing::Values(false), // use_fast_math - ::testing::Values(NVFP4ScalingMode::RowScaled1D), // scaling_mode - ::testing::Values(true), // use_4over6 - ::testing::Values(kNVTENVFP44Over6ErrMAE, - kNVTENVFP44Over6ErrMSE), // err_mode - ::testing::Values(false, true)), // use_4over6_err_fast_math + ::testing::Values( + NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMAE, false}, + NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMAE, true}, + NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMSE, false}, + NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMSE, true})), // four_over_six_config [](const testing::TestParamInfo& info) { return test_name(info.param); }); From 1e311efec58b28378d720cb3d260132f7b56b824 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 15:31:02 -0700 Subject: [PATCH 49/57] Initial 448 vs 256 implementation Signed-off-by: Ziang Li --- docs/envvars.rst | 10 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 87 +++++--- tests/cpp/operator/test_dequantize_nvfp4.cu | 40 +++- tests/cpp/test_common.cu | 12 ++ tests/cpp/test_common.h | 2 + tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 10 + .../nvfp4/test_nvfp4_quantize_exact.py | 8 + tests/pytorch/test_recipe.py | 30 ++- .../common/cast/dispatch/quantize.cuh | 24 ++- .../common/cast/nvfp4/core_nvfp4.cuh | 11 +- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 10 +- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 83 ++++---- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 99 ++++----- .../quantize_transpose_nvfp4_tuned_1D.cuh | 74 ++++--- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 4 + transformer_engine/common/common.h | 21 +- .../transformer_engine/transformer_engine.h | 49 ++++- transformer_engine/common/recipe/__init__.py | 10 +- transformer_engine/common/recipe/nvfp4.cu | 4 +- .../common/transformer_engine.cpp | 20 +- .../common/transpose/cast_transpose.h | 5 +- ...quantize_transpose_vector_blockwise_fp4.cu | 195 +++++++++++------- transformer_engine/pytorch/csrc/common.h | 2 + .../pytorch/csrc/extensions/cast.cpp | 32 +-- transformer_engine/pytorch/csrc/quantizer.cpp | 26 ++- .../pytorch/csrc/type_converters.cpp | 2 + .../custom_recipes/quantization_ref_nvfp4.py | 35 +++- transformer_engine/pytorch/quantization.py | 9 + .../pytorch/tensor/grouped_tensor.py | 3 + .../pytorch/tensor/nvfp4_tensor.py | 21 ++ .../tensor/storage/grouped_tensor_storage.py | 20 ++ .../tensor/storage/nvfp4_tensor_storage.py | 8 + 32 files changed, 673 insertions(+), 293 deletions(-) diff --git a/docs/envvars.rst b/docs/envvars.rst index d2456f3364..2a9a23fd2f 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -291,7 +291,13 @@ Kernel Configuration :Type: ``str`` (``weights``, ``activations``, or ``all``) :Default: unset - :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower configured input-domain error, and ties select map-to-6. This mode uses 256 as the global E4M3 scale bound instead of the 448 bound. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled; activation and backward scopes therefore require ``NVTE_NVFP4_DISABLE_RHT=1`` and ``NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1``. + :Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower configured input-domain error, and ties select map-to-6. By default, this mode keeps the standard NVFP4 global E4M3 scale bound of 448. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled; activation and backward scopes therefore require ``NVTE_NVFP4_DISABLE_RHT=1`` and ``NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1``. + +.. envvar:: NVTE_NVFP4_4OVER6_E4M3_USE_256 + + :Type: ``str`` (``weights``, ``activations``, or ``all``) + :Default: unset + :Description: Select NVFP4 4over6 quantizers that use 256 instead of 448 as the global E4M3 scale bound. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. This option is only meaningful for tensor roles that also enable :envvar:`NVTE_NVFP4_4OVER6`. .. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE @@ -299,7 +305,7 @@ Kernel Configuration :Default: ``MAE`` :Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. -.. envvar:: NVTE_NVFP4_4OVER6_ERR_FAST_MATH +.. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH :Type: ``int`` (0 or 1) :Default: ``0`` diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 1efc737cf9..4a36edfca2 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -63,15 +63,15 @@ std::vector create_transpose(const InputType* const input, const size // Compute the global encode scale factor for a given global amax float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math, - const bool use_4over6 = false) { + const bool use_e4m3_256 = false) { float fp8_max = 448.0f; - if (use_4over6) { + if (use_e4m3_256) { fp8_max = 256.0f; } constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return the max normalized value - const float max_norm_clamp = (use_fast_math && !use_4over6) + const float max_norm_clamp = (use_fast_math && !use_e4m3_256) ? Numeric_Traits::maxNorm : Numeric_Traits::maxNorm; @@ -105,8 +105,9 @@ enum class NVFP4ScalingMode { struct NVFP4FourOverSixTestConfig { bool enabled = false; + bool e4m3_use_256 = false; NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE; - bool err_fast_math = false; + bool err_use_fast_math = false; }; bool use_2d_quantization(const NVFP4ScalingMode scaling_mode) { @@ -120,10 +121,12 @@ bool row_scaled_nvfp4(const NVFP4ScalingMode scaling_mode) { NVFP4FourOverSixQuantization compute_4over6_quantization_scales( const float block_amax, const float global_encode_scale) { constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; constexpr float scale_expansion_factor = 1.5f; const float base_sf_high_precision = block_amax / fp4_max * global_encode_scale; - const float sf_high_precision_map4 = base_sf_high_precision * scale_expansion_factor; - const float sf_high_precision_map6 = base_sf_high_precision; + const float sf_high_precision_map4 = + fminf(base_sf_high_precision * scale_expansion_factor, fp8_max); + const float sf_high_precision_map6 = fminf(base_sf_high_precision, fp8_max); const fp8e4m3 scale_map4 = static_cast(sf_high_precision_map4); const fp8e4m3 scale_map6 = static_cast(sf_high_precision_map6); @@ -200,11 +203,13 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float global_amax, const bool use_fast_math, const bool use_4over6 = false, + const bool use_e4m3_256 = false, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, + use_e4m3_256); constexpr size_t block_size_X = 16; const size_t blocks_X = divide_round_up(cols, block_size_X); @@ -301,10 +306,12 @@ void compute_2d_mathematical_scales(float (*OP)(const float), std::vector>& math_scales, const bool use_fast_math, const bool use_4over6 = false, + const bool use_e4m3_256 = false, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, + use_e4m3_256); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -358,15 +365,17 @@ void quantize_nvfp4_2d(float (*OP)(const float), const float global_amax, const bool use_fast_math, const bool use_4over6 = false, + const bool use_e4m3_256 = false, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math, - use_4over6, four_over_six_candidate); + use_4over6, use_e4m3_256, four_over_six_candidate); - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, use_4over6); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, + use_e4m3_256); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -450,14 +459,15 @@ void quantize_nvfp4(float (*OP)(const float), const bool use_fast_math, const bool use_2d_quantization = false, const bool use_4over6 = false, + const bool use_e4m3_256 = false, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { if (use_2d_quantization) { quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6, four_over_six_candidate); + use_fast_math, use_4over6, use_e4m3_256, four_over_six_candidate); } else { quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6, four_over_six_candidate); + use_fast_math, use_4over6, use_e4m3_256, four_over_six_candidate); } } @@ -477,6 +487,7 @@ void compute_ref(float (*OP)(const float), const bool use_2d_quantization = false, const bool row_scaled_nvfp4 = false, const bool use_4over6 = false, + const bool use_e4m3_256 = false, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { @@ -489,7 +500,7 @@ void compute_ref(float (*OP)(const float), // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math, - use_4over6, four_over_six_candidate); + use_4over6, use_e4m3_256, four_over_six_candidate); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -517,10 +528,10 @@ void compute_ref(float (*OP)(const float), // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, *amax, - use_fast_math, use_4over6, + use_fast_math, use_4over6, use_e4m3_256, four_over_six_candidate); // scales already filled quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, *amax, - use_fast_math, use_4over6, + use_fast_math, use_4over6, use_e4m3_256, four_over_six_candidate); // scales_t already filled return; @@ -540,6 +551,7 @@ void compute_ref(float (*OP)(const float), use_fast_math, use_2d_quantization, use_4over6, + use_e4m3_256, four_over_six_candidate); } return; @@ -547,9 +559,11 @@ void compute_ref(float (*OP)(const float), // Ref impl for basic NVFP4 quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, *amax, - use_fast_math, use_2d_quantization, use_4over6, four_over_six_candidate); + use_fast_math, use_2d_quantization, use_4over6, use_e4m3_256, + four_over_six_candidate); quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, *amax, - use_fast_math, use_2d_quantization, use_4over6, four_over_six_candidate); + use_fast_math, use_2d_quantization, use_4over6, use_e4m3_256, + four_over_six_candidate); } void compare_nvfp4_tensors(const std::string& name, @@ -776,13 +790,14 @@ void performTest(float (*OP)(const float), const bool use_fast_math, const NVFP4ScalingMode scaling_mode = NVFP4ScalingMode::Block1D, const bool use_4over6 = false, + const bool use_e4m3_256 = false, const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE, - const bool use_4over6_err_fast_math = false) { + const bool use_4over6_err_use_fast_math = false) { using namespace test; if (use_4over6 && use_fast_math) { std::cout << "WARNING: Plain NVFP4 fast math is ignored for 4over6. " - "Use use_4over6_err_fast_math to test the 4over6 candidate " + "Use use_4over6_err_use_fast_math to test the 4over6 candidate " "error fast-math path." << std::endl; } @@ -818,6 +833,7 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, rowwise, columnwise, NVTE_NVFP4_1D_SCALING); output.set_nvfp4_4over6(use_4over6); + output.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -831,7 +847,7 @@ void performTest(float (*OP)(const float), fillCase(&input, InputsFillCase::uniform); if (use_4over6 && is_row_scaled_nvfp4) { - constexpr float target_row_amax = 256.0f * 6.0f * 8.0f; + const float target_row_amax = (use_e4m3_256 ? 256.0f : 448.0f) * 6.0f * 8.0f; auto *input_vals = input.rowwise_cpu_dptr(); for (size_t row = 0; row < rows; ++row) { float row_amax = 0.0f; @@ -883,7 +899,7 @@ void performTest(float (*OP)(const float), } else { // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues if (use_4over6) { - ref_amax.assign(1, 256.0f * 6.0f * 8.0f); + ref_amax.assign(1, (use_e4m3_256 ? 256.0f : 448.0f) * 6.0f * 8.0f); } else { ref_amax.assign(1, 448.0f * 6.0f * 8.0f); } @@ -919,6 +935,7 @@ void performTest(float (*OP)(const float), is_2d_quantization, is_row_scaled_nvfp4, use_4over6, + use_e4m3_256, NVFP4FourOverSixCandidate::Map4); compute_ref(OP, input.rowwise_cpu_dptr(), @@ -935,6 +952,7 @@ void performTest(float (*OP)(const float), is_2d_quantization, is_row_scaled_nvfp4, use_4over6, + use_e4m3_256, NVFP4FourOverSixCandidate::Map6); } else { compute_ref(OP, @@ -967,8 +985,9 @@ void performTest(float (*OP)(const float), quant_config.set_rng_state(rng_state.data()); quant_config.set_nvfp4_2d_quantization(is_2d_quantization); quant_config.set_nvfp4_4over6(use_4over6); + quant_config.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); quant_config.set_nvfp4_4over6_err_mode(err_mode); - quant_config.set_nvfp4_4over6_err_fast_math(use_4over6 && use_4over6_err_fast_math); + quant_config.set_nvfp4_4over6_err_use_fast_math(use_4over6 && use_4over6_err_use_fast_math); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -1113,7 +1132,8 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, performTest(OP, tensor_dims, use_fast_math, scaling_mode, config.enabled, - config.err_mode, config.err_fast_math); + config.e4m3_use_256, config.err_mode, + config.err_use_fast_math); ); } @@ -1152,13 +1172,18 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) const NVFP4FourOverSixTestConfig& config = std::get<5>(param); if (config.enabled) { name += "X4OVER6"; + if (config.e4m3_use_256) { + name += "XE4M3_USE_256"; + } else { + name += "XE4M3_USE_448"; + } if (config.err_mode == kNVTENVFP44Over6ErrMSE) { name += "XMSE"; } else { name += "XMAE"; } - if (config.err_fast_math) { - name += "XERR_FAST_MATH"; + if (config.err_use_fast_math) { + name += "XERR_USE_FAST_MATH"; } } return name; @@ -1204,10 +1229,14 @@ INSTANTIATE_TEST_SUITE_P( NVFP4ScalingMode::RowScaled1D, NVFP4ScalingMode::Block2D), // scaling_mode ::testing::Values( - NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMAE, false}, - NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMAE, true}, - NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMSE, false}, - NVFP4FourOverSixTestConfig{true, kNVTENVFP44Over6ErrMSE, true})), // four_over_six_config + NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMAE, false}, + NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMAE, true}, + NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMSE, false}, + NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMSE, true}, + NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMAE, false}, + NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMAE, true}, + NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMSE, false}, + NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMSE, true})), // four_over_six_config [](const testing::TestParamInfo& info) { return test_name(info.param); }); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 714574facd..823fb2f9c8 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -47,8 +47,8 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, size_t rows, size_t cols, size_t scale_stride, - bool use_4over6) { - const float factor_inv = 1.0f / (6.0f * (use_4over6 ? 256.0f : 448.0f)); + bool use_e4m3_256) { + const float factor_inv = 1.0f / (6.0f * (use_e4m3_256 ? 256.0f : 448.0f)); constexpr size_t BLOCK_SIZE = 16; const size_t Mread = cols / BLOCK_SIZE; const size_t bytes_per_block = BLOCK_SIZE / 2; @@ -92,7 +92,8 @@ float compute_amax(test::Tensor &t, size_t rows, size_t cols) { template void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, const bool row_scaled_nvfp4, - const bool use_4over6) { + const bool use_4over6, + const bool use_e4m3_256) { using namespace test; DType otype = TypeInfo::dtype; @@ -108,7 +109,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Configure quantized tensor amax size_t amax_size = 1; quantized.set_nvfp4_4over6(use_4over6); + quantized.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); ASSERT_EQ(quantized.nvfp4_4over6(), use_4over6); + ASSERT_EQ(quantized.nvfp4_4over6_e4m3_use_256(), use_4over6 && use_e4m3_256); if (row_scaled_nvfp4) { quantized.set_row_scaled_nvfp4(true); amax_size = rows; @@ -122,6 +125,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, if (rows > 0 && cols > 0) { QuantizationConfigWrapper quant_config; quant_config.set_nvfp4_4over6(use_4over6); + quant_config.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -152,7 +156,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, std::make_unique(rows * cols); compute_ref_dequantize_nvfp4( fp4_data, scales, amax_vals, ref_output.get(), - rows, cols, scale_stride, use_4over6); + rows, cols, scale_stride, use_4over6 && use_e4m3_256); // Compare results from TE and reference impls auto [atol, rtol] = getTolerances(otype); @@ -163,7 +167,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, template void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, const bool row_scaled_nvfp4, - const bool use_4over6) { + const bool use_4over6, + const bool use_e4m3_256) { using namespace test; DType otype = TypeInfo::dtype; @@ -173,7 +178,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); quantized_compact.set_nvfp4_4over6(use_4over6); + quantized_compact.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); ASSERT_EQ(quantized_compact.nvfp4_4over6(), use_4over6); + ASSERT_EQ(quantized_compact.nvfp4_4over6_e4m3_use_256(), use_4over6 && use_e4m3_256); if (row_scaled_nvfp4) { quantized_compact.set_row_scaled_nvfp4(true); } else if (rows > 0 && cols > 0) { @@ -185,6 +192,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, if (rows > 0 && cols > 0) { QuantizationConfigWrapper quant_config; quant_config.set_nvfp4_4over6(use_4over6); + quant_config.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); cudaDeviceSynchronize(); } @@ -198,7 +206,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); quantized_swizzled.set_nvfp4_4over6(use_4over6); + quantized_swizzled.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); ASSERT_EQ(quantized_swizzled.nvfp4_4over6(), use_4over6); + ASSERT_EQ(quantized_swizzled.nvfp4_4over6_e4m3_use_256(), use_4over6 && use_e4m3_256); if (row_scaled_nvfp4) { quantized_swizzled.set_row_scaled_nvfp4(true); } else { @@ -274,6 +284,7 @@ class DequantizeNVFP4TestSuite : public ::testing::TestWithParam , transformer_engine::DType, bool, + bool, bool>> {}; TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) @@ -286,10 +297,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); const bool use_4over6 = std::get<3>(GetParam()); + const bool use_e4m3_256 = use_4over6 && std::get<4>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, use_e4m3_256); ); } @@ -300,6 +312,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(nvfp4_tensor_dims), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Bool(), + ::testing::Bool(), ::testing::Bool()), [](const testing::TestParamInfo& info) { @@ -307,7 +320,11 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + - (std::get<3>(info.param) ? "FourOverSix" : "Default"); + (std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" + + (std::get<3>(info.param) + ? (std::get<4>(info.param) ? "E4M3Use256" : "E4M3Use448") + : (std::get<4>(info.param) ? "E4M3Use256Ignored" + : "E4M3Use448")); return name; } ); @@ -316,6 +333,7 @@ class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam , transformer_engine::DType, bool, + bool, bool>> {}; TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) @@ -328,10 +346,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); const bool use_4over6 = std::get<3>(GetParam()); + const bool use_e4m3_256 = use_4over6 && std::get<4>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, use_e4m3_256); ); } @@ -342,6 +361,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(nvfp4_tensor_dims), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Bool(), + ::testing::Bool(), ::testing::Bool()), [](const testing::TestParamInfo& info) { @@ -350,6 +370,10 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<1>(info.param)) + "X" + (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + (std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" + + (std::get<3>(info.param) + ? (std::get<4>(info.param) ? "E4M3Use256" : "E4M3Use448") + : (std::get<4>(info.param) ? "E4M3Use256Ignored" + : "E4M3Use448")) + "X" + "Swizzled"; return name; } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4687160f1b..055c3b744e 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -446,12 +446,24 @@ void Tensor::set_nvfp4_4over6(bool nvfp4_4over6) { tensor_.set_nvfp4_4over6(nvfp4_4over6); } +void Tensor::set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256) { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "NVFP4 4over6 E4M3 256 scale bound is only supported for NVFP4 tensors."); + tensor_.set_nvfp4_4over6_e4m3_use_256(nvfp4_4over6_e4m3_use_256); +} + bool Tensor::nvfp4_4over6() const { NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, "NVFP4 4over6 is only supported for NVFP4 tensors."); return tensor_.get_nvfp4_4over6(); } +bool Tensor::nvfp4_4over6_e4m3_use_256() const { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "NVFP4 4over6 E4M3 256 scale bound is only supported for NVFP4 tensors."); + return tensor_.get_nvfp4_4over6_e4m3_use_256(); +} + void Tensor::to_cpu() { if (data_rowwise_) { data_rowwise_->to_cpu(); } if (data_columnwise_) { data_columnwise_->to_cpu(); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 46fd320080..8c6c900040 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -294,12 +294,14 @@ class Tensor { } bool nvfp4_4over6() const; + bool nvfp4_4over6_e4m3_use_256() const; void set_tensor_amax_nullptr(); void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales); void set_row_scaled_nvfp4(bool row_scaled_nvfp4); void set_nvfp4_4over6(bool nvfp4_4over6); + void set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256); void to_cpu(); void from_cpu(); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 10d578aa95..5f9a839a78 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -29,8 +29,11 @@ def check_nvfp4_gemm_versus_reference( w_columnwise: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, four_over_six_err_mode: str = "MAE", ): + if four_over_six_e4m3_use_256 and not use_4over6: + pytest.skip("E4M3 256 bound is only meaningful for 4over6") te_dtype = tex.DType.kFloat4E2M1 # Setup device and random seed @@ -62,6 +65,7 @@ def check_nvfp4_gemm_versus_reference( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) w_quantizer = NVFP4Quantizer( @@ -73,6 +77,7 @@ def check_nvfp4_gemm_versus_reference( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) @@ -130,6 +135,7 @@ def check_nvfp4_gemm_versus_reference( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) w_ref_quantizer = NVFP4QuantizerRef( @@ -140,6 +146,7 @@ def check_nvfp4_gemm_versus_reference( eps=0.0, quant_tile_shape=(1, 16), use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) @@ -442,6 +449,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_e4m3_use_256", [False, True], ids=["e4m3_448", "e4m3_256"]) @pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_gemm_versus_reference( M: int, @@ -455,6 +463,7 @@ def test_nvfp4_gemm_versus_reference( is_w_columnwise: bool, row_scaled_nvfp4: bool, use_4over6: bool, + four_over_six_e4m3_use_256: bool, four_over_six_err_mode: str, ): if row_scaled_nvfp4: @@ -475,6 +484,7 @@ def test_nvfp4_gemm_versus_reference( w_columnwise=is_w_columnwise, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index b18521f1b1..12ffdc9329 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -53,8 +53,11 @@ def check_quantization_nvfp4_versus_reference( with_2d_quantization: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, four_over_six_err_mode: str = "MAE", ) -> None: + if four_over_six_e4m3_use_256 and not use_4over6: + pytest.skip("E4M3 256 bound is only meaningful for 4over6") maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6, x_dtype, M, N ) @@ -81,6 +84,7 @@ def check_quantization_nvfp4_versus_reference( with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) if use_cpp_allocator: @@ -116,6 +120,7 @@ def check_quantization_nvfp4_versus_reference( quant_tile_shape=quant_tile_shape, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -193,6 +198,7 @@ def check_quantization_nvfp4_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("four_over_six_e4m3_use_256", [False, True], ids=["e4m3_448", "e4m3_256"]) @pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, @@ -204,6 +210,7 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization: bool, row_scaled_nvfp4: bool, use_4over6: bool, + four_over_six_e4m3_use_256: bool, four_over_six_err_mode: str, ) -> None: check_quantization_nvfp4_versus_reference( @@ -216,6 +223,7 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, four_over_six_err_mode=four_over_six_err_mode, ) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 1c1b79e7aa..9425c7019a 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -520,13 +520,21 @@ def test_quantizer_update(self, module_class): [None, "weights", "activations", "all"], ids=["default", "weights", "activations", "all"], ) +@pytest.mark.parametrize( + "nvfp4_4over6_e4m3_use_256", + [None, "weights", "activations", "all"], + ids=["e4m3_448", "e4m3_256_weights", "e4m3_256_activations", "e4m3_256_all"], +) @pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) -def test_nvfp4_row_scaled_quantizer_roles(nvfp4_4over6, nvfp4_4over6_err_mode): +def test_nvfp4_row_scaled_quantizer_roles( + nvfp4_4over6, nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode +): recipe = NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, nvfp4_4over6=nvfp4_4over6, + nvfp4_4over6_e4m3_use_256=nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, row_scaled_activation=True, ) @@ -540,6 +548,17 @@ def expected_use_4over6(tensor_type): return tensor_type != "weight" return False + def expected_e4m3_use_256(tensor_type): + if not expected_use_4over6(tensor_type): + return False + if nvfp4_4over6_e4m3_use_256 == "all": + return True + if nvfp4_4over6_e4m3_use_256 == "weights": + return tensor_type == "weight" + if nvfp4_4over6_e4m3_use_256 == "activations": + return tensor_type != "weight" + return False + forward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="forward", @@ -549,6 +568,9 @@ def expected_use_4over6(tensor_type): assert [q.use_4over6 for q in forward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("input", "weight", "output") ] + assert [q.four_over_six_e4m3_use_256 for q in forward_quantizers] == [ + expected_e4m3_use_256(tensor_type) for tensor_type in ("input", "weight", "output") + ] assert [q.four_over_six_err_mode for q in forward_quantizers] == [nvfp4_4over6_err_mode] * 3 assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) @@ -568,6 +590,9 @@ def expected_use_4over6(tensor_type): assert [q.use_4over6 for q in role_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("weight", "input", "output", "input") ] + assert [q.four_over_six_e4m3_use_256 for q in role_quantizers] == [ + expected_e4m3_use_256(tensor_type) for tensor_type in ("weight", "input", "output", "input") + ] assert [q.four_over_six_err_mode for q in role_quantizers] == [nvfp4_4over6_err_mode] * 4 backward_quantizers = NVFP4BlockScalingRecipeState( @@ -583,6 +608,9 @@ def expected_use_4over6(tensor_type): assert [q.use_4over6 for q in backward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("grad_output", "grad_input") ] + assert [q.four_over_six_e4m3_use_256 for q in backward_quantizers] == [ + expected_e4m3_use_256(tensor_type) for tensor_type in ("grad_output", "grad_input") + ] assert [q.four_over_six_err_mode for q in backward_quantizers] == [nvfp4_4over6_err_mode] * 2 diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 3cc4ed93e5..e8acf5f276 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -103,7 +103,13 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, "Tensor and quantization config have inconsistent options for NVFP4 4over6."); + NVTE_CHECK( + quant_config_cpp.nvfp4_4over6_e4m3_use_256 == output_tensor->nvfp4_4over6_e4m3_use_256, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); const bool use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(use_4over6 || !quant_config_cpp.nvfp4_4over6_e4m3_use_256, + "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); if (row_scaled_nvfp4) { @@ -141,8 +147,9 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/row_scaled_nvfp4, /*use_4over6=*/use_4over6, + /*use_4over6_e4m3_use_256=*/quant_config_cpp.nvfp4_4over6_e4m3_use_256, /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, - /*nvfp4_4over6_err_fast_math=*/quant_config_cpp.nvfp4_4over6_err_fast_math, + /*nvfp4_4over6_err_use_fast_math=*/quant_config_cpp.nvfp4_4over6_err_use_fast_math, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } @@ -259,7 +266,13 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens auto dtype = grad_tensor->dtype(); NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, "Tensor and quantization config have inconsistent options for NVFP4 4over6."); + NVTE_CHECK( + quant_config_cpp.nvfp4_4over6_e4m3_use_256 == output_tensor->nvfp4_4over6_e4m3_use_256, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); const bool use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(use_4over6 || !quant_config_cpp.nvfp4_4over6_e4m3_use_256, + "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, @@ -292,8 +305,9 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/false, /*use_4over6=*/use_4over6, + /*use_4over6_e4m3_use_256=*/quant_config_cpp.nvfp4_4over6_e4m3_use_256, /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, - /*nvfp4_4over6_err_fast_math=*/quant_config_cpp.nvfp4_4over6_err_fast_math, + /*nvfp4_4over6_err_use_fast_math=*/quant_config_cpp.nvfp4_4over6_err_use_fast_math, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } @@ -392,8 +406,14 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou for (const auto *output_tensor : output_tensors) { NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, "Tensor and quantization config have inconsistent options for NVFP4 4over6."); + NVTE_CHECK( + quant_config_cpp.nvfp4_4over6_e4m3_use_256 == output_tensor->nvfp4_4over6_e4m3_use_256, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); } const bool use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(use_4over6 || !quant_config_cpp.nvfp4_4over6_e4m3_use_256, + "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "2D quantization is not supported for group quantize."); NVTE_CHECK(!use_4over6, "NVFP4 4over6 quantization is not supported for group quantize."); diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index d04417d47d..ace00fb415 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -75,13 +75,14 @@ namespace core { #if FP4_TYPE_SUPPORTED using namespace ptx; -// Compute the global encode scale factor for a given global amax -// 4over6 uses 256 instead of 448 to leave room for the map-to-4 scale expansion -template +// Compute the global encode scale factor for a given global amax. +// NVFP4 uses 448 by default. Some 4over6 tensors use 256 to leave room for +// map-to-4 scale expansion. +template __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { using namespace detail; - constexpr float fp8_max = USE_4OVER6 ? 256.0f : TypeExtrema::max; // 448.0f; - constexpr float fp4_max = TypeExtrema::max; // 6.0f; + constexpr float fp8_max = USE_E4M3_256 ? 256.0f : TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 630eb540ca..0228d2786a 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -31,7 +31,7 @@ namespace dispatch { namespace nvfp4 { namespace dequantize_kernel { #if FP4_TYPE_SUPPORTED -template +template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, const float *const tensor_amax, const size_t N, const size_t M, @@ -64,7 +64,7 @@ __global__ void __launch_bounds__(512) value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; float amax = ROW_SCALED_NVFP4 ? tensor_amax[y] : tensor_amax[0]; - constexpr float factor_inv = 1.0f / (6.0f * (USE_4OVER6 ? 256.0f : 448.0f)); + constexpr float factor_inv = 1.0f / (6.0f * (USE_E4M3_256 ? 256.0f : 448.0f)); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { @@ -91,7 +91,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; - const bool use_4over6 = input.nvfp4_4over6; + const bool use_e4m3_256 = input.nvfp4_4over6_e4m3_use_256; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -115,9 +115,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) TRANSFORMER_ENGINE_SWITCH_CONDITION( row_scaled_nvfp4, ROW_SCALED_NVFP4, TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6, USE_4OVER6, + use_e4m3_256, USE_E4M3_256, dequantize_fp4_kernel<<>>( + USE_E4M3_256><<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), reinterpret_cast(input.amax.dptr), N, Mread, diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index d0fdc30369..9d4b0410c8 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -50,10 +50,12 @@ __device__ __forceinline__ void compute_4over6_decoding_scaling_factors( const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, nvfp4_scale_t &S_dec_b_fp8_map6) { constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_max = detail::TypeExtrema::max; // 448.0f constexpr float scale_expansion_factor = 1.5f; const float base_sf_high_precision = block_amax / fp4_max * S_enc; - const float sf_high_precision_map4 = base_sf_high_precision * scale_expansion_factor; - const float sf_high_precision_map6 = base_sf_high_precision; + const float sf_high_precision_map4 = + fminf(base_sf_high_precision * scale_expansion_factor, fp8_max); + const float sf_high_precision_map6 = fminf(base_sf_high_precision, fp8_max); S_dec_b_fp8_map4 = static_cast(sf_high_precision_map4); S_dec_b_fp8_map6 = static_cast(sf_high_precision_map6); } @@ -124,7 +126,7 @@ __device__ __forceinline__ float compute_4over6_error(const float diff) { } } -template +template __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t S_dec_b_fp8, const float global_amax, float *err) { @@ -170,10 +172,11 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( const uint16_t out_dequant_4_lo = out_dequant_4 & 0xFFFF; constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f - constexpr float fp8_4over6_max = 256.0f; + constexpr float fp8_4over6_max = + USE_E4M3_256 ? 256.0f : detail::TypeExtrema::max; // 448.0f constexpr float err_denom = fp4_max * fp8_4over6_max; const float sf = static_cast(S_dec_b_fp8); - if constexpr (USE_ERR_FAST_MATH) { + if constexpr (USE_ERR_USE_FAST_MATH) { const float dequant[8] = { __half2float(__ushort_as_half(out_dequant_1_lo)), __half2float(__ushort_as_half(out_dequant_1_hi)), @@ -243,7 +246,8 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( return out; } -template +template __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], const float (&second_half)[8], const QuantizationScales4Over6 &scaling_factors, @@ -251,29 +255,29 @@ __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8] float &err_map6, uint32_t (&rOut_map4)[2], uint32_t (&rOut_map6)[2]) { if constexpr (REVERSE_PACK_ORDER) { - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } else { - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } @@ -291,7 +295,8 @@ selected_4over6_scale(const bool pick_map4, const QuantizationScales4Over6 &scal return scaling_factors.S_dec_b_fp8_map6; } -template +template __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], const float (&second_half)[8], const QuantizationScales4Over6 &scaling_factors, @@ -303,7 +308,7 @@ __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8] __align__(8) uint32_t rOut_map4[2]; __align__(8) uint32_t rOut_map6[2]; - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, err_map4, err_map6, rOut_map4, rOut_map6); @@ -403,7 +408,7 @@ __device__ __forceinline__ void load_4over6_vec_index_halves_16x(const vec_type } } -template +template __device__ __forceinline__ void quantize_4over6_candidates_16x( const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { @@ -412,7 +417,7 @@ __device__ __forceinline__ void quantize_4over6_candidates_16x( load_4over6_contiguous_halves_16x(x, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } @@ -450,8 +455,8 @@ __device__ __forceinline__ bool record_and_select_4over6_2d_block( return scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; } -template +template __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( const float (&x)[16], const float block_amax, const float global_encode_scale, const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, @@ -460,8 +465,8 @@ __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( nvfp4_scale_t &S_dec_b_fp8, QuantizationCandidates4Over6 &candidates) { const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( block_amax, global_encode_scale, global_decode_scale); - quantize_4over6_candidates_16x(x, scaling_factors, global_amax, - candidates); + quantize_4over6_candidates_16x( + x, scaling_factors, global_amax, candidates); const bool pick_map4 = record_and_select_4over6_2d_block( @@ -524,8 +529,8 @@ __device__ __forceinline__ void store_selected_4over6_packed_16x( store_4over6_packed_16x(candidates.selected_packed(pick_map4), output_vec); } -template +template __device__ __forceinline__ void quantize_4over6_contiguous_16x( const input_type *x, const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -533,12 +538,12 @@ __device__ __forceinline__ void quantize_4over6_contiguous_16x( float second_half[8]; load_4over6_contiguous_halves_16x(x, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_pair_array_16x( const pair_type (&x)[2][4], const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -546,11 +551,12 @@ __device__ __forceinline__ void quantize_4over6_pair_array_16x( float second_half[8]; load_4over6_pair_array_halves_16x(x, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { @@ -559,13 +565,13 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( load_4over6_vec2_array_halves_16x(x, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } -template +template __device__ __forceinline__ void quantize_4over6_vec2_array_16x( const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -573,11 +579,12 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_16x( float second_half[8]; load_4over6_vec2_array_halves_16x(x, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { @@ -586,13 +593,13 @@ __device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } -template +template __device__ __forceinline__ void quantize_4over6_vec_index_16x( const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -600,7 +607,7 @@ __device__ __forceinline__ void quantize_4over6_vec_index_16x( float second_half[8]; load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 2774930fa6..09372f8b98 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -785,7 +785,8 @@ __global__ void __launch_bounds__(THREADS_NUM) template + bool USE_4OVER6_E4M3_USE_256, NVTENVFP44Over6ErrMode USE_4OVER6_ERR_MODE, + bool USE_4OVER6_ERR_USE_FAST_MATH> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -915,14 +916,14 @@ __global__ void __launch_bounds__(THREADS_NUM) const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f - : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); // NOTE: This is to match with how emulation code was written. const float S_dec_rowwise = 1.0 / S_enc_rowwise; const float S_enc_colwise = (amax_colwise_ptr == nullptr) ? S_enc_rowwise - : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); const float S_dec_colwise = 1.0 / S_enc_colwise; const float global_amax_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f : *amax_rowwise_ptr; const float global_amax_colwise = @@ -1112,13 +1113,12 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t block_col = threadIdx.x % BLOCK_DIM; QuantizationCandidates4Over6 candidates; nvfp4_scale_t S_dec_b_fp8; - const bool pick_map4 = - quantize_and_select_4over6_2d_block_16x( - x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, - block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, - candidates); + const bool pick_map4 = quantize_and_select_4over6_2d_block_16x< + USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_USE_FAST_MATH, USE_4OVER6_E4M3_USE_256, BLOCK_DIM, + BLOCKS_PER_TILE_Y, BLOCKS_PER_TILE_X>( + x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, + block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, + candidates); const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; @@ -1276,13 +1276,12 @@ __global__ void __launch_bounds__(THREADS_NUM) if constexpr (USE_4OVER6) { QuantizationCandidates4Over6 candidates; nvfp4_scale_t S_dec_b_fp8; - const bool pick_map4 = - quantize_and_select_4over6_2d_block_16x( - in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, - block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, - S_dec_b_fp8, candidates); + const bool pick_map4 = quantize_and_select_4over6_2d_block_16x< + USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_USE_FAST_MATH, USE_4OVER6_E4M3_USE_256, BLOCK_DIM, + BLOCKS_PER_TILE_Y, BLOCKS_PER_TILE_X>( + in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, + block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, S_dec_b_fp8, + candidates); const size_t scales_offset_Y = scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; @@ -1422,10 +1421,12 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; + const bool use_4over6_e4m3_use_256 = + use_4over6 && quant_config && quant_config->nvfp4_4over6_e4m3_use_256; const NVTENVFP44Over6ErrMode use_4over6_err_mode = use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; - const bool use_4over6_err_fast_math = - use_4over6 && quant_config && quant_config->nvfp4_4over6_err_fast_math; + const bool use_4over6_err_use_fast_math = + use_4over6 && quant_config && quant_config->nvfp4_4over6_err_use_fast_math; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -1547,35 +1548,37 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( use_4over6_err_mode, USE_4OVER6_ERR_MODE, { TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_err_fast_math, USE_4OVER6_ERR_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = - quantize_transpose_nvfp4_kernel; - - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel< - COMPUTE_ACTIVATIONS, ParamOP, OP, IType, USE_STOCHASTIC_ROUNDING, - RETURN_TRANSPOSE, USE_4OVER6, USE_4OVER6_ERR_MODE, - USE_4OVER6_ERR_FAST_MATH>; - } - using FourOverSixScratch = - core::QuantizationScratch4Over6; - constexpr size_t dshmem_size = - base_dshmem_size + - FourOverSixScratch::template dynamic_shared_memory_size< - use_2d_quantization, USE_4OVER6>(); - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, - scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, - amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, - rng_state); - });); + use_4over6_err_use_fast_math, USE_4OVER6_ERR_USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6_e4m3_use_256, USE_4OVER6_E4M3_USE_256, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel< + COMPUTE_ACTIVATIONS, ParamOP, OP, IType, USE_STOCHASTIC_ROUNDING, + RETURN_TRANSPOSE, USE_4OVER6, USE_4OVER6_E4M3_USE_256, + USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_USE_FAST_MATH>; + } + using FourOverSixScratch = + core::QuantizationScratch4Over6; + constexpr size_t dshmem_size = + base_dshmem_size + + FourOverSixScratch::template dynamic_shared_memory_size< + use_2d_quantization, USE_4OVER6>(); + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, + scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, + amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, + rng_state); + }););); }); }););); #else diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index dc0c663473..68bdc84da2 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -186,7 +186,8 @@ compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const f } template + bool USE_4OVER6_E4M3_USE_256, NVTENVFP44Over6ErrMode USE_4OVER6_ERR_MODE, + bool USE_4OVER6_ERR_USE_FAST_MATH> __device__ __forceinline__ void colwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, @@ -237,7 +238,8 @@ __device__ __forceinline__ void colwise_scaling( const auto scaling_factors = core::compute_4over6_nvfp4_quantization_scaling_factors(block_amax[w], S_enc_colwise); - core::quantize_4over6_contiguous_16x( + core::quantize_4over6_contiguous_16x( rIn[w], scaling_factors, global_amax_colwise, S_dec_b_fp8, rOut); // Store scaling factors to SMEM buffer (R2S) @@ -280,7 +282,8 @@ __device__ __forceinline__ void colwise_scaling( } template + bool USE_4OVER6_E4M3_USE_256, NVTENVFP44Over6ErrMode USE_4OVER6_ERR_MODE, + bool USE_4OVER6_ERR_USE_FAST_MATH> __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, @@ -341,7 +344,8 @@ __device__ __forceinline__ void rowwise_scaling( if (row_idx < rows) { block_global_amax = amax_rowwise_ptr[row_idx]; block_S_enc_rowwise = - core::compute_global_encode_scaling_factor_FP4(block_global_amax); + core::compute_global_encode_scaling_factor_FP4( + block_global_amax); } else { block_global_amax = 1.0f; block_S_enc_rowwise = 1.0f; @@ -355,10 +359,12 @@ __device__ __forceinline__ void rowwise_scaling( __align__(8) uint32_t rOut[WAVES]; if (bank_group == 0) { - core::quantize_4over6_pair_array_16x( + core::quantize_4over6_pair_array_16x( rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } else { - core::quantize_4over6_pair_array_16x( + core::quantize_4over6_pair_array_16x( rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } @@ -427,7 +433,7 @@ __device__ __forceinline__ void rowwise_scaling( template + bool USE_4OVER6_ERR_USE_FAST_MATH, bool USE_4OVER6_E4M3_USE_256> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -495,12 +501,14 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f - : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + : core::compute_global_encode_scaling_factor_FP4( + *amax_rowwise_ptr); const float S_enc_colwise = (amax_colwise_ptr == nullptr) ? S_enc_rowwise - : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + : core::compute_global_encode_scaling_factor_FP4( + *amax_colwise_ptr); // Original NVFP4 uses a scalar per-tensor amax for both rowwise and columnwise output. // If no dedicated columnwise amax buffer is allocated, the rowwise amax is that same scalar. const float global_amax_colwise = (amax_colwise_ptr == nullptr) @@ -654,13 +662,13 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D // NVFP4 Quantization rowwise_scaling( + USE_4OVER6_E4M3_USE_256, USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_USE_FAST_MATH>( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { - colwise_scaling( + colwise_scaling( sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, global_amax_colwise, stage_Y, stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); } @@ -764,10 +772,12 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; + const bool use_4over6_e4m3_use_256 = + use_4over6 && quant_config && quant_config->nvfp4_4over6_e4m3_use_256; const NVTENVFP44Over6ErrMode use_4over6_err_mode = use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; - const bool use_4over6_err_fast_math = - use_4over6 && quant_config && quant_config->nvfp4_4over6_err_fast_math; + const bool use_4over6_err_use_fast_math = + use_4over6 && quant_config && quant_config->nvfp4_4over6_err_use_fast_math; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; // If transposed output is allocated, return the transposed data @@ -887,21 +897,24 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( use_4over6_err_mode, USE_4OVER6_ERR_MODE, TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_err_fast_math, USE_4OVER6_ERR_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< - USE_STOCHASTIC_ROUNDING, - /*USE_FAST_MATH=*/false, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, - /*USE_4OVER6=*/true, USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_FAST_MATH>; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, - scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, - amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, - rng_state); - }););););); + use_4over6_err_use_fast_math, USE_4OVER6_ERR_USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6_e4m3_use_256, USE_4OVER6_E4M3_USE_256, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< + USE_STOCHASTIC_ROUNDING, + /*USE_FAST_MATH=*/false, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, + /*USE_4OVER6=*/true, USE_4OVER6_ERR_MODE, + USE_4OVER6_ERR_USE_FAST_MATH, USE_4OVER6_E4M3_USE_256>; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, + scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, + amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, + rng_state); + });););););); } else { const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; TRANSFORMER_ENGINE_SWITCH_CONDITION( @@ -914,7 +927,8 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< USE_STOCHASTIC_ROUNDING, USE_FAST_MATH, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, /*USE_4OVER6=*/false, kNVTENVFP44Over6ErrMAE, - /*USE_4OVER6_ERR_FAST_MATH=*/false>; + /*USE_4OVER6_ERR_USE_FAST_MATH=*/false, + /*USE_4OVER6_E4M3_USE_256=*/false>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 1b3aa1950f..34f687ab9d 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -234,6 +234,10 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz chunk.set_nvfp4_4over6(source.get_nvfp4_4over6()); continue; } + if (param_type == NVTETensorParam::kNVTENVFP44Over6E4M3Use256) { + chunk.set_nvfp4_4over6_e4m3_use_256(source.get_nvfp4_4over6_e4m3_use_256()); + continue; + } auto param = source.get_parameter(param_type); auto param_dptr = reinterpret_cast(param.data_ptr); auto param_dtype = static_cast(param.dtype); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index bafac1518a..a1c31bae20 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -180,11 +180,16 @@ struct Tensor { bool row_scaled_nvfp4 = false; /*! \brief Whether NVFP4 uses 4over6 block scale selection. * - * Only meaningful for NVFP4 tensors. 4over6 tensors use 256 as their - * global E4M3 scale bound and store a selected map-to-4/map-to-6 - * candidate for each 1x16 block. + * Only meaningful for NVFP4 tensors. 4over6 tensors store a selected + * map-to-4/map-to-6 candidate for each 1x16 block. */ bool nvfp4_4over6 = false; + /*! \brief Whether NVFP4 4over6 uses 256 as the global E4M3 scale bound. + * + * Only meaningful when nvfp4_4over6 is true. If false, the standard NVFP4 + * E4M3 bound 448 is used. + */ + bool nvfp4_4over6_e4m3_use_256 = false; /*! Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -197,7 +202,8 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales sizeof(uint8_t), // kNVTERowScaledNVFP4 - sizeof(uint8_t) // kNVTENVFP44Over6 + sizeof(uint8_t), // kNVTENVFP44Over6 + sizeof(uint8_t) // kNVTENVFP44Over6E4M3Use256 }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -215,6 +221,7 @@ struct Tensor { with_gemm_swizzled_scales = false; row_scaled_nvfp4 = false; nvfp4_4over6 = false; + nvfp4_4over6_e4m3_use_256 = false; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } @@ -487,8 +494,9 @@ struct QuantizationConfig { bool stochastic_rounding = false; bool use_fast_math = false; bool nvfp4_4over6 = false; + bool nvfp4_4over6_e4m3_use_256 = false; NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; - bool nvfp4_4over6_err_fast_math = false; + bool nvfp4_4over6_err_use_fast_math = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -500,8 +508,9 @@ struct QuantizationConfig { sizeof(uint8_t), // stochastic_rounding sizeof(uint8_t), // use_fast_math sizeof(uint8_t), // nvfp4_4over6 + sizeof(uint8_t), // nvfp4_4over6_e4m3_use_256 sizeof(uint8_t), // nvfp4_4over6_err_mode - sizeof(uint8_t) // nvfp4_4over6_err_fast_math + sizeof(uint8_t) // nvfp4_4over6_err_use_fast_math }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index fd13ec958f..dcff2fbe1c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -85,11 +85,17 @@ enum NVTETensorParam { kNVTERowScaledNVFP4 = 8, /*! Whether an NVFP4 tensor is encoded with 4over6 semantics. * - * This is part of the tensor data contract: 4over6 tensors use 256 as - * their global E4M3 scale bound, so downstream dequantization and GEMM - * scale consumers must decode them differently from default NVFP4 tensors. + * This records whether block scales were selected by comparing map-to-4 + * and map-to-6 candidates. */ kNVTENVFP44Over6 = 9, + /*! Whether an NVFP4 4over6 tensor uses 256 as its global E4M3 scale bound. + * + * This is part of the tensor data contract. Downstream dequantization and + * GEMM scale consumers must use the same global E4M3 bound used during + * quantization. + */ + kNVTENVFP44Over6E4M3Use256 = 10, kNVTENumTensorParams }; @@ -400,24 +406,29 @@ enum NVTEQuantizationConfigAttribute { * * 4over6 evaluates map-to-4 and map-to-6 candidates for each 1x16 block, * stores the lower-error candidate according to - * kNVTEQuantizationConfigNVFP44Over6ErrMode, and emits tensor data that - * uses a 256 global E4M3 scale bound. The output tensor's + * kNVTEQuantizationConfigNVFP44Over6ErrMode. The output tensor's * kNVTENVFP44Over6 metadata must match this option. */ kNVTEQuantizationConfigNVFP44Over6 = 8, + /*! Whether NVFP4 4over6 should use 256 as the global E4M3 scale bound. + * + * If disabled, 4over6 uses the default NVFP4 448 bound. The output tensor's + * kNVTENVFP44Over6E4M3Use256 metadata must match this option. + */ + kNVTEQuantizationConfigNVFP44Over6E4M3Use256 = 9, /*! Candidate-selection error mode for NVFP4 4over6 quantization. * * The value is an NVTENVFP44Over6ErrMode encoded as uint8_t. It is only * used when kNVTEQuantizationConfigNVFP44Over6 is enabled. */ - kNVTEQuantizationConfigNVFP44Over6ErrMode = 9, + kNVTEQuantizationConfigNVFP44Over6ErrMode = 10, /*! Whether the NVFP4 4over6 candidate error computation may use fast math. * * This is intentionally separate from kNVTEQuantizationConfigUseFastMath so * callers can keep candidate selection bitwise deterministic independent * of ordinary NVFP4 fast-math settings. */ - kNVTEQuantizationConfigNVFP44Over6ErrFastMath = 10, + kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath = 11, kNVTEQuantizationConfigNumAttributes }; @@ -823,6 +834,11 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTENVFP44Over6, &val, sizeof(val)); } + void set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256) { + const auto val = static_cast(nvfp4_4over6_e4m3_use_256); + nvte_set_tensor_param_v2(tensor_, kNVTENVFP44Over6E4M3Use256, &val, sizeof(val)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -871,6 +887,12 @@ class TensorWrapper { return static_cast(val); } + bool get_nvfp4_4over6_e4m3_use_256() const { + uint8_t val = 0; + nvte_get_tensor_param_v2(tensor_, kNVTENVFP44Over6E4M3Use256, &val, sizeof(val), nullptr); + return static_cast(val); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -1373,6 +1395,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether NVFP4 4over6 uses the 256 global E4M3 scale bound */ + void set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256) { + const auto val = static_cast(nvfp4_4over6_e4m3_use_256); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6E4M3Use256, + &val, sizeof(val)); + } + /*! \brief Set NVFP4 4over6 candidate-selection error mode */ void set_nvfp4_4over6_err_mode(NVTENVFP44Over6ErrMode mode) { const auto val = static_cast(mode); @@ -1381,10 +1410,10 @@ class QuantizationConfigWrapper { } /*! \brief Set whether NVFP4 4over6 candidate error computation uses fast math */ - void set_nvfp4_4over6_err_fast_math(bool use_fast_math) { + void set_nvfp4_4over6_err_use_fast_math(bool use_fast_math) { const auto val = static_cast(use_fast_math); - nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6ErrFastMath, - &val, sizeof(val)); + nvte_set_quantization_config_attribute( + config_, kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath, &val, sizeof(val)); } private: diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 336ea3c517..f3658efacf 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -528,12 +528,14 @@ class NVFP4BlockScaling(Recipe): Select tensors that use NVFP4 4over6. In this mode NVFP4 quantization evaluates per-block map-to-4 and map-to-6 candidates and chooses the one with lower configured error. Ties choose map-to-6. The - global E4M3 scale bound is 256 in this mode instead of 448. The ``activations`` scope applies to every non-weight tensor role. Random Hadamard transforms and stochastic rounding are not yet supported on tensors that use 4over6; activation and backward scopes therefore require ``disable_rht=True`` and ``disable_stochastic_rounding=True``. + nvfp4_4over6_e4m3_use_256 : {None, 'weights', 'activations', 'all'}, default = None + Select 4over6 tensors that use 256 as the global E4M3 scale + bound. If unset, 4over6 uses the standard NVFP4 448 bound. nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE' Error metric used by NVFP4 4over6 candidate selection. backward_override : {None, 'high_precision', 'dequantized'}, default = None @@ -551,6 +553,7 @@ class NVFP4BlockScaling(Recipe): disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" nvfp4_4over6: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6", None) + nvfp4_4over6_e4m3_use_256: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6_E4M3_USE_256", None) nvfp4_4over6_err_mode: str = os.getenv("NVTE_NVFP4_4OVER6_ERR_MODE", "MAE").upper() fp4_format: Format = Format.E2M1 @@ -570,6 +573,10 @@ def __post_init__(self) -> None: assert ( self.nvfp4_4over6 in _NVFP4_4OVER6_SCOPES ), "NVTE_NVFP4_4OVER6 must be unset or one of: 'weights', 'activations', 'all'." + assert self.nvfp4_4over6_e4m3_use_256 in _NVFP4_4OVER6_SCOPES, ( + "NVTE_NVFP4_4OVER6_E4M3_USE_256 must be unset or one of: " + "'weights', 'activations', 'all'." + ) assert ( self.nvfp4_4over6_err_mode in _NVFP4_4OVER6_ERR_MODES ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE'." @@ -608,6 +615,7 @@ def _make_repr(self) -> str: f"backward_override={self.backward_override}, " f"row_scaled_activation={self.row_scaled_activation}, " f"nvfp4_4over6={self.nvfp4_4over6}, " + f"nvfp4_4over6_e4m3_use_256={self.nvfp4_4over6_e4m3_use_256}, " f"nvfp4_4over6_err_mode={self.nvfp4_4over6_err_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 23ca156092..4b1950cbd3 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -924,8 +924,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr; void *alpha_ptr = tOut->data.dptr; - const float fp8_max_A = tA->nvfp4_4over6 ? 256.0f : 448.0f; - const float fp8_max_B = tB->nvfp4_4over6 ? 256.0f : 448.0f; + const float fp8_max_A = tA->nvfp4_4over6_e4m3_use_256 ? 256.0f : 448.0f; + const float fp8_max_B = tB->nvfp4_4over6_e4m3_use_256 ? 256.0f : 448.0f; // check for not null pointers NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 5eb238dd51..1971e78bdc 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -858,6 +858,9 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTENVFP44Over6: t.nvfp4_4over6 = static_cast(*reinterpret_cast(buf)); break; + case kNVTENVFP44Over6E4M3Use256: + t.nvfp4_4over6_e4m3_use_256 = static_cast(*reinterpret_cast(buf)); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -944,6 +947,9 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTENVFP44Over6: *reinterpret_cast(buf) = static_cast(t->nvfp4_4over6); break; + case kNVTENVFP44Over6E4M3Use256: + *reinterpret_cast(buf) = static_cast(t->nvfp4_4over6_e4m3_use_256); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -1058,13 +1064,16 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNVFP44Over6: bool_to_uint8(config_.nvfp4_4over6, buf); break; + case kNVTEQuantizationConfigNVFP44Over6E4M3Use256: + bool_to_uint8(config_.nvfp4_4over6_e4m3_use_256, buf); + break; case kNVTEQuantizationConfigNVFP44Over6ErrMode: { const auto val = static_cast(config_.nvfp4_4over6_err_mode); std::memcpy(buf, &val, attr_size); break; } - case kNVTEQuantizationConfigNVFP44Over6ErrFastMath: - bool_to_uint8(config_.nvfp4_4over6_err_fast_math, buf); + case kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath: + bool_to_uint8(config_.nvfp4_4over6_err_use_fast_math, buf); break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); @@ -1124,6 +1133,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNVFP44Over6: uint8_to_bool(buf, config_.nvfp4_4over6); break; + case kNVTEQuantizationConfigNVFP44Over6E4M3Use256: + uint8_to_bool(buf, config_.nvfp4_4over6_e4m3_use_256); + break; case kNVTEQuantizationConfigNVFP44Over6ErrMode: { const auto val = *reinterpret_cast(buf); NVTE_CHECK(val == static_cast(kNVTENVFP44Over6ErrMAE) || @@ -1132,8 +1144,8 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, config_.nvfp4_4over6_err_mode = static_cast(val); break; } - case kNVTEQuantizationConfigNVFP44Over6ErrFastMath: - uint8_to_bool(buf, config_.nvfp4_4over6_err_fast_math); + case kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath: + uint8_to_bool(buf, config_.nvfp4_4over6_err_use_fast_math); break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index c65d0d2d17..87ead8d186 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -68,8 +68,9 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, - const bool nvfp4_4over6_err_fast_math, const SimpleTensor &noop_tensor, cudaStream_t stream); + const bool use_4over6, const bool use_4over6_e4m3_use_256, + const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, const bool nvfp4_4over6_err_use_fast_math, + const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index bcfd0ea211..ff8caa5712 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -308,7 +308,8 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kUse4Over6, bool kUse4Over6E4M3Use256, NVTENVFP44Over6ErrMode k4Over6ErrMode, + bool kUse4Over6ErrUseFastMath> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -416,7 +417,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; const float global_encode_scale = - kIsE8Scaling ? 1.0f : compute_global_encode_scaling_factor_FP4(global_amax[0]); + kIsE8Scaling ? 1.0f + : compute_global_encode_scaling_factor_FP4(global_amax[0]); constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; const float global_decode_scale = 1.0 / global_encode_scale; @@ -511,9 +513,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float row_global_encode_scale = global_encode_scale; if constexpr (kRowScaledNVFP4) { row_global_encode_scale = - row_idx < num_rows - ? compute_global_encode_scaling_factor_FP4(global_amax[row_idx]) - : 1.0f; + row_idx < num_rows ? compute_global_encode_scaling_factor_FP4( + global_amax[row_idx]) + : 1.0f; } const float row_global_encode_scale_multiplier = kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; @@ -549,8 +551,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t participant_idx = data_row_idx % kFP4BlockScalingSize; nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec2_array_candidates_16x( + nvfp4_core::quantize_4over6_vec2_array_candidates_16x< + k4Over6ErrMode, kUse4Over6ErrUseFastMath, kUse4Over6E4M3Use256>( smem_vec, scaling_factors, row_global_amax, candidates); const bool pick_map4 = nvfp4_core::record_and_select_4over6_2d_block( + nvfp4_core::quantize_4over6_vec2_array_16x( smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } @@ -714,8 +717,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t participant_idx = data_col_idx % kFP4BlockScalingSize; nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec_index_candidates_16x( + nvfp4_core::quantize_4over6_vec_index_candidates_16x< + k4Over6ErrMode, kUse4Over6ErrUseFastMath, kUse4Over6E4M3Use256>( smem_vec, smem_idx, scaling_factors, global_amax[0], candidates); const bool pick_map4 = nvfp4_core::record_and_select_4over6_2d_block( + nvfp4_core::quantize_4over6_vec_index_16x( smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } @@ -810,8 +814,9 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, - const bool nvfp4_4over6_err_fast_math, const SimpleTensor& noop_tensor, cudaStream_t stream) { + const bool use_4over6, const bool use_4over6_e4m3_use_256, + const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, const bool nvfp4_4over6_err_use_fast_math, + const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -830,9 +835,12 @@ void quantize_transpose_vector_blockwise_fp4( "Row-scaled NVFP4 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); + NVTE_CHECK(use_4over6 || !use_4over6_e4m3_use_256, + "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); const NVTENVFP44Over6ErrMode use_4over6_err_mode = use_4over6 ? nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; - const bool use_4over6_err_fast_math = use_4over6 && nvfp4_4over6_err_fast_math; + const bool use_4over6_err_use_fast_math = use_4over6 && nvfp4_4over6_err_use_fast_math; + const bool enabled_4over6_e4m3_use_256 = use_4over6 && use_4over6_e4m3_use_256; const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -906,69 +914,102 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( swizzled_scale, kSwizzledScale, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, kApplyStochasticRounding, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_2d_quantization, kIs2DBlockScaling, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - row_scaled_nvfp4, kRowScaledNVFP4, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6, kUse4Over6, - - TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( - use_4over6_err_mode, k4Over6ErrMode, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_err_fast_math, kUse4Over6ErrFastMath, - - size_t smem_bytes = kSMemSize * sizeof(InputType); - auto kernel = - block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, - kPow2Scale, kAligned, float, InputType, - OutputType, ScaleType, kSwizzledScale, - kApplyStochasticRounding, - kIs2DBlockScaling, kRowScaledNVFP4, - kUse4Over6, k4Over6ErrMode, - kUse4Over6ErrFastMath>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK(err == cudaSuccess, - "Failed to set dynamic shared " - "memory size."); - } kernel<<>>( - reinterpret_cast( - input.dptr), - reinterpret_cast( - global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast( - scale_inv_t.dptr), - row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, - scale_t_stride_y, kScaleBlockDim, epsilon, - rng_state, - noop_ptr);) // kUse4Over6ErrFastMath - ) // k4Over6ErrMode - ) // kUse4Over6 - ) // kRowScaledNVFP4 - ) // kIs2DBlockScaling - ) // kApplyStochasticRounding - ) // kSwizzledScale - ) // kAligned - ) // kReturnTranspose - ) // kReturnIdentity - ) // OutputType - ) // InputType + if (use_4over6) { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, kRowScaledNVFP4, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + enabled_4over6_e4m3_use_256, kUse4Over6E4M3Use256, + + TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( + use_4over6_err_mode, k4Over6ErrMode, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_4over6_err_use_fast_math, + kUse4Over6ErrUseFastMath, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, + kAligned, float, InputType, OutputType, + ScaleType, kSwizzledScale, + /*kApplyStochasticRounding=*/false, + kIs2DBlockScaling, kRowScaledNVFP4, + /*kUse4Over6=*/true, kUse4Over6E4M3Use256, + k4Over6ErrMode, kUse4Over6ErrUseFastMath>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory " + "size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast( + global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, + scale_stride_y, scale_t_stride_x, + scale_t_stride_y, kScaleBlockDim, epsilon, + rng_state, + noop_ptr);) // kUse4Over6ErrUseFastMath + ) // k4Over6ErrMode + ) // kUse4Over6E4M3Use256 + ) // kRowScaledNVFP4 + ) // kIs2DBlockScaling + } else { + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kApplyStochasticRounding, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, kRowScaledNVFP4, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling, + kRowScaledNVFP4, /*kUse4Over6=*/false, + /*kUse4Over6E4M3Use256=*/false, kNVTENVFP44Over6ErrMAE, + /*kUse4Over6ErrUseFastMath=*/false>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, scale_stride_y, + scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, + epsilon, rng_state, + noop_ptr);) // kRowScaledNVFP4 + ) // kIs2DBlockScaling + ) // kApplyStochasticRounding + }) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); #else diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index e97c14e68e..8b86f26b24 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -329,6 +329,8 @@ class NVFP4Quantizer : public Quantizer { bool stochastic_rounding; // Whether emitted NVFP4 tensors use 4over6 candidate selection. bool use_4over6; + // Whether emitted NVFP4 4over6 tensors use 256 as the global E4M3 scale bound. + bool four_over_six_e4m3_use_256; NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode; // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. bool row_scaled_nvfp4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 83ac117e2c..caaca112d2 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -725,6 +725,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const bool use_4over6 = quantizer_cpp_list[0]->use_4over6; + const bool four_over_six_e4m3_use_256 = quantizer_cpp_list[0]->four_over_six_e4m3_use_256; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); @@ -872,7 +873,8 @@ std::tuple, std::vector, bool> bulk_alloc tensor_py_list.emplace_back(NVFP4TensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales, - py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, py::arg("use_4over6") = use_4over6)); + py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, py::arg("use_4over6") = use_4over6, + py::arg("four_over_six_e4m3_use_256") = four_over_six_e4m3_use_256)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -891,6 +893,7 @@ std::tuple, std::vector, bool> bulk_alloc tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); tensor_wrapper.set_nvfp4_4over6(use_4over6); + tensor_wrapper.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -1039,9 +1042,13 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, need_separate_rng_states, quant_config_list, quant_config_list_colwise); for (auto &config : quant_config_list) { + config.set_nvfp4_4over6(quantizer.use_4over6); + config.set_nvfp4_4over6_e4m3_use_256(quantizer.four_over_six_e4m3_use_256); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } for (auto &config : quant_config_list_colwise) { + config.set_nvfp4_4over6(quantizer.use_4over6); + config.set_nvfp4_4over6_e4m3_use_256(quantizer.four_over_six_e4m3_use_256); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } @@ -1053,7 +1060,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 // NVFP4 4over6 candidate error math is controlled separately by - // NVTE_NVFP4_4OVER6_ERR_FAST_MATH. + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && !quantizer.use_4over6) { for (auto &config : quant_config_list) { @@ -1064,14 +1071,14 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } - const auto use_4over6_err_fast_math = - transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_FAST_MATH"); - if (use_4over6_err_fast_math) { + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { for (auto &config : quant_config_list) { - config.set_nvfp4_4over6_err_fast_math(true); + config.set_nvfp4_4over6_err_use_fast_math(true); } for (auto &config : quant_config_list_colwise) { - config.set_nvfp4_4over6_err_fast_math(true); + config.set_nvfp4_4over6_err_use_fast_math(true); } } @@ -1219,11 +1226,12 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, for (auto &config : quant_config_list) { config.set_nvfp4_4over6(quantizer.use_4over6); + config.set_nvfp4_4over6_e4m3_use_256(quantizer.four_over_six_e4m3_use_256); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } // NVFP4 4over6 candidate error math is controlled separately by - // NVTE_NVFP4_4OVER6_ERR_FAST_MATH. + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && !quantizer.use_4over6) { for (auto &config : quant_config_list) { @@ -1231,11 +1239,11 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, } } - const auto use_4over6_err_fast_math = - transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_FAST_MATH"); - if (use_4over6_err_fast_math) { + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { for (auto &config : quant_config_list) { - config.set_nvfp4_4over6_err_fast_math(true); + config.set_nvfp4_4over6_err_use_fast_math(true); } } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 5a0759a99c..e671fa996c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1730,6 +1730,8 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); this->use_4over6 = quantizer.attr("use_4over6").cast(); + this->four_over_six_e4m3_use_256 = + this->use_4over6 && quantizer.attr("four_over_six_e4m3_use_256").cast(); const auto nvfp4_4over6_err_mode = quantizer.attr("four_over_six_err_mode").cast(); if (nvfp4_4over6_err_mode == "MAE") { this->nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; @@ -1788,6 +1790,7 @@ std::pair NVFP4Quantizer::create_tensor( " (got shape=", shape, ")"); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; const bool use_4over6 = this->use_4over6; + const bool four_over_six_e4m3_use_256 = this->four_over_six_e4m3_use_256; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -1856,6 +1859,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["use_4over6"] = py::cast(use_4over6); + kwargs["four_over_six_e4m3_use_256"] = py::cast(four_over_six_e4m3_use_256); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1887,6 +1891,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["device"] = py::cast(device); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["use_4over6"] = py::cast(use_4over6); + kwargs["four_over_six_e4m3_use_256"] = py::cast(four_over_six_e4m3_use_256); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1921,6 +1926,7 @@ std::pair NVFP4Quantizer::create_tensor( out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); out_cpp.set_nvfp4_4over6(use_4over6); + out_cpp.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1950,6 +1956,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; const bool use_4over6 = this->use_4over6; + const bool four_over_six_e4m3_use_256 = this->four_over_six_e4m3_use_256; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -2025,6 +2032,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["use_4over6"] = py::cast(use_4over6); + kwargs["four_over_six_e4m3_use_256"] = py::cast(four_over_six_e4m3_use_256); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -2100,6 +2108,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool use_4over6 = this->use_4over6; + const bool four_over_six_e4m3_use_256 = this->four_over_six_e4m3_use_256; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -2107,6 +2117,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); tensor.attr("_use_4over6") = py::cast(use_4over6); + tensor.attr("_four_over_six_e4m3_use_256") = py::cast(four_over_six_e4m3_use_256); // Coerce row-wise data if (rowwise_usage) { @@ -2212,6 +2223,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); out_cpp.set_nvfp4_4over6(use_4over6); + out_cpp.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2303,8 +2315,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); quant_config.set_nvfp4_4over6(this->use_4over6); + quant_config.set_nvfp4_4over6_e4m3_use_256(this->four_over_six_e4m3_use_256); quant_config.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); quant_config_columnwise.set_nvfp4_4over6(this->use_4over6); + quant_config_columnwise.set_nvfp4_4over6_e4m3_use_256(this->four_over_six_e4m3_use_256); quant_config_columnwise.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); if (this->use_4over6) { @@ -2453,18 +2467,18 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 // NVFP4 4over6 candidate error math is controlled separately by - // NVTE_NVFP4_4OVER6_ERR_FAST_MATH. + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); if (use_fast_math && !this->use_4over6) { quant_config.set_use_fast_math(true); quant_config_columnwise.set_use_fast_math(true); } - const auto use_4over6_err_fast_math = - transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_FAST_MATH"); - if (use_4over6_err_fast_math) { - quant_config.set_nvfp4_4over6_err_fast_math(true); - quant_config_columnwise.set_nvfp4_4over6_err_fast_math(true); + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { + quant_config.set_nvfp4_4over6_err_use_fast_math(true); + quant_config_columnwise.set_nvfp4_4over6_err_use_fast_math(true); } if (this->with_rht) { diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 69ed98b435..fc707a209d 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -136,6 +136,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); const bool use_4over6 = tensor.attr("_use_4over6").cast(); + const bool four_over_six_e4m3_use_256 = tensor.attr("_four_over_six_e4m3_use_256").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -167,6 +168,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); ret.set_row_scaled_nvfp4(row_scaled_nvfp4); ret.set_nvfp4_4over6(use_4over6); + ret.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index a6c1bffe60..7d8e8d4745 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -222,6 +222,7 @@ class NVFP4TensorRef(QuantizedTensorStorage): global_amax_row: Optional[torch.Tensor] = None global_amax_col: Optional[torch.Tensor] = None use_4over6: bool = False + four_over_six_e4m3_use_256: bool = False dtype: Optional[torch.dtype] = None device: Optional[torch.device] = None @@ -352,6 +353,7 @@ def __init__( quant_tile_shape: Tuple[int, int] = (1, 16), row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, four_over_six_err_mode: str = "MAE", with_rht: bool = False, with_random_sign_mask: bool = True, @@ -382,6 +384,7 @@ def __init__( self.quant_tile_shape = quant_tile_shape self.row_scaled_nvfp4 = row_scaled_nvfp4 self.use_4over6 = use_4over6 + self.four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 self.four_over_six_err_mode = four_over_six_err_mode self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -471,18 +474,23 @@ def _quantize_blockwise_4over6_reference( row_scaled_nvfp4: bool, tile_len_y: int, four_over_six_err_mode: str, + four_over_six_e4m3_use_256: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize NVFP4 with 4over6 candidate selection. This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale, the configured error is computed in the original input domain with the - 6 * 256 denominator, and ties choose map-to-6. + selected global E4M3 denominator, and ties choose map-to-6. """ m, num_blocks, tile_len_x = x.shape n = num_blocks * tile_len_x FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) - GLOBAL_SCALE_E4M3_MAX = torch.tensor(256.0, device=x.device, dtype=torch.float32) + GLOBAL_SCALE_E4M3_MAX = torch.tensor( + 256.0 if four_over_six_e4m3_use_256 else 448.0, + device=x.device, + dtype=torch.float32, + ) decode_scale_base = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale decode_scale_map4 = decode_scale_base * 1.5 @@ -576,6 +584,7 @@ def _quantize_blockwise_reference( pow_2_scales: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, four_over_six_err_mode: str = "MAE", eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -609,8 +618,9 @@ def _quantize_blockwise_reference( x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + global_scale_e4m3_max = 256.0 if (use_4over6 and four_over_six_e4m3_use_256) else 448.0 GLOBAL_SCALE_E4M3_MAX = torch.tensor( - 256.0 if use_4over6 else 448.0, device=x.device, dtype=torch.float32 + global_scale_e4m3_max, device=x.device, dtype=torch.float32 ) decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) @@ -656,6 +666,7 @@ def _quantize_blockwise_reference( row_scaled_nvfp4, tile_len_y, four_over_six_err_mode, + four_over_six_e4m3_use_256, ) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) @@ -819,6 +830,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ pow_2_scales=self.pow_2_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, use_4over6=self.use_4over6, + four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, four_over_six_err_mode=self.four_over_six_err_mode, eps=self.eps, ) @@ -844,6 +856,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, use_4over6=self.use_4over6, + four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, four_over_six_err_mode=self.four_over_six_err_mode, eps=self.eps, ) @@ -885,6 +898,7 @@ def quantize( global_amax_row=global_amax_row, global_amax_col=global_amax_col, use_4over6=self.use_4over6, + four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, dtype=tensor.dtype, device=tensor.device, quant_dtype=self.dtype, @@ -933,6 +947,7 @@ def update_quantized( dst.global_amax_row = global_amax_row dst.global_amax_col = global_amax_col dst.use_4over6 = self.use_4over6 + dst.four_over_six_e4m3_use_256 = self.four_over_six_e4m3_use_256 dst.dtype = src.dtype dst.quant_dtype = self.dtype dst.original_shape = original_shape @@ -1044,11 +1059,21 @@ def qgemm( qresult_w_use_4over6 = getattr( qresult_w, "use_4over6", getattr(qresult_w, "_use_4over6", self.use_4over6) ) - if qresult_x_use_4over6: + qresult_x_use_256 = getattr( + qresult_x, + "four_over_six_e4m3_use_256", + getattr(qresult_x, "_four_over_six_e4m3_use_256", self.four_over_six_e4m3_use_256), + ) + qresult_w_use_256 = getattr( + qresult_w, + "four_over_six_e4m3_use_256", + getattr(qresult_w, "_four_over_six_e4m3_use_256", self.four_over_six_e4m3_use_256), + ) + if qresult_x_use_4over6 and qresult_x_use_256: fp8_max_x = 256.0 else: fp8_max_x = 448.0 - if qresult_w_use_4over6: + if qresult_w_use_4over6 and qresult_w_use_256: fp8_max_w = 256.0 else: fp8_max_w = 448.0 diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 403c5cbc0d..bab8d15559 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1662,6 +1662,14 @@ def _make(tensor_type: str) -> NVFP4Quantizer: use_4over6 = tensor_type == "weight" elif self.recipe.nvfp4_4over6 == "activations": use_4over6 = tensor_type != "weight" + use_4over6_e4m3_use_256 = False + if use_4over6: + if self.recipe.nvfp4_4over6_e4m3_use_256 == "all": + use_4over6_e4m3_use_256 = True + elif self.recipe.nvfp4_4over6_e4m3_use_256 == "weights": + use_4over6_e4m3_use_256 = tensor_type == "weight" + elif self.recipe.nvfp4_4over6_e4m3_use_256 == "activations": + use_4over6_e4m3_use_256 = tensor_type != "weight" return NVFP4Quantizer( fp4_dtype=self.dtype, rowwise=True, @@ -1676,6 +1684,7 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and self.recipe.row_scaled_activation ), use_4over6=use_4over6, + four_over_six_e4m3_use_256=use_4over6_e4m3_use_256, four_over_six_err_mode=self.recipe.nvfp4_4over6_err_mode, ) diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index 834194a0ce..e47389dc27 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -94,6 +94,7 @@ def __new__( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, ): if ( shapes is not None @@ -168,6 +169,7 @@ def __new__( with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, ) return instance @@ -201,6 +203,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 dst.use_4over6 = src.use_4over6 + dst.four_over_six_e4m3_use_256 = src.four_over_six_e4m3_use_256 def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index c4def40e4f..cef3bd4f2b 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -132,6 +132,8 @@ class NVFP4Quantizer(Quantizer): row_scaled_nvfp4: bool """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection.""" use_4over6: bool + """Whether 4over6 uses 256 instead of 448 as the global E4M3 scale bound.""" + four_over_six_e4m3_use_256: bool """NVFP4 4over6 candidate-selection error mode.""" four_over_six_err_mode: str @@ -152,6 +154,7 @@ def __init__( stochastic_rounding: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, four_over_six_err_mode: str = "MAE", with_random_sign_mask: bool = True, ) -> None: @@ -165,6 +168,7 @@ def __init__( self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 self.use_4over6 = use_4over6 + self.four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 self.four_over_six_err_mode = four_over_six_err_mode.upper() if self.four_over_six_err_mode not in ("MAE", "MSE"): raise ValueError("four_over_six_err_mode must be 'MAE' or 'MSE'.") @@ -215,6 +219,8 @@ def copy(self) -> NVFP4Quantizer: stochastic_rounding=self.stochastic_rounding, row_scaled_nvfp4=self.row_scaled_nvfp4, use_4over6=self.use_4over6, + four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, + four_over_six_err_mode=self.four_over_six_err_mode, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -368,6 +374,7 @@ def __new__( with_gemm_swizzled_scales: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, **kwargs, ): instance = super().__new__( @@ -384,6 +391,7 @@ def __new__( *args, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, **kwargs, ) return instance @@ -543,6 +551,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m self._amax_columnwise, self._row_scaled_nvfp4, self._use_4over6, + self._four_over_six_e4m3_use_256, self.shape[-1], ) return sharded_tensors, metadata @@ -568,6 +577,7 @@ def fsdp_post_all_gather( amax_columnwise, row_scaled_nvfp4, use_4over6, + four_over_six_e4m3_use_256, K, ) = metadata @@ -594,6 +604,7 @@ def fsdp_post_all_gather( out._amax_columnwise = amax_columnwise out._row_scaled_nvfp4 = row_scaled_nvfp4 out._use_4over6 = use_4over6 + out._four_over_six_e4m3_use_256 = four_over_six_e4m3_use_256 else: # Construct new tensor (first iteration) out = NVFP4Tensor( @@ -611,6 +622,7 @@ def fsdp_post_all_gather( with_gemm_swizzled_scales=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, ) # Derive columnwise data locally via transpose instead of all-gathering it @@ -751,6 +763,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, use_4over6=tensor._use_4over6, + four_over_six_e4m3_use_256=tensor._four_over_six_e4m3_use_256, ) # Default case @@ -772,6 +785,7 @@ def _make_in_reduce_ex( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, ) -> NVFP4Tensor: """Build NVFP4Tensor, for use in __reduce__ @@ -794,6 +808,7 @@ def _make_in_reduce_ex( with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -814,6 +829,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._with_gemm_swizzled_scales, self._row_scaled_nvfp4, self._use_4over6, + self._four_over_six_e4m3_use_256, ), ) @@ -868,6 +884,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales self._row_scaled_nvfp4 = tensor._row_scaled_nvfp4 self._use_4over6 = tensor._use_4over6 + self._four_over_six_e4m3_use_256 = tensor._four_over_six_e4m3_use_256 return # Quantize to FP8 @@ -990,6 +1007,7 @@ def forward( with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, use_4over6=tensor._use_4over6, + four_over_six_e4m3_use_256=tensor._four_over_six_e4m3_use_256, ) @staticmethod @@ -1034,6 +1052,7 @@ def backward( with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, row_scaled_nvfp4=grad._row_scaled_nvfp4, use_4over6=grad._use_4over6, + four_over_six_e4m3_use_256=grad._four_over_six_e4m3_use_256, ) return dgrad, None return grad.view(ctx.shape), None @@ -1120,6 +1139,7 @@ def forward( with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, use_4over6=tensor._use_4over6, + four_over_six_e4m3_use_256=tensor._four_over_six_e4m3_use_256, ) @staticmethod @@ -1164,6 +1184,7 @@ def backward( with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, row_scaled_nvfp4=grad._row_scaled_nvfp4, use_4over6=grad._use_4over6, + four_over_six_e4m3_use_256=grad._four_over_six_e4m3_use_256, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 713d04ccf2..dfde0e9e18 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -74,6 +74,7 @@ def _initialize_storage_fields( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, ) -> None: """ Initialize a GroupedTensor. @@ -151,6 +152,7 @@ def _initialize_storage_fields( instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance.row_scaled_nvfp4 = row_scaled_nvfp4 instance.use_4over6 = use_4over6 + instance.four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 def __new__( cls, @@ -178,6 +180,7 @@ def __new__( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -205,6 +208,7 @@ def __new__( with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, ) return instance @@ -329,6 +333,15 @@ def use_4over6(self) -> bool: def use_4over6(self, use_4over6: bool) -> None: self._use_4over6 = use_4over6 + @property + def four_over_six_e4m3_use_256(self) -> bool: + """Whether grouped NVFP4 4over6 tensors use the 256 E4M3 scale bound.""" + return self._four_over_six_e4m3_use_256 + + @four_over_six_e4m3_use_256.setter + def four_over_six_e4m3_use_256(self, four_over_six_e4m3_use_256: bool) -> None: + self._four_over_six_e4m3_use_256 = four_over_six_e4m3_use_256 + def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], "GroupedTensorStorage"]: @@ -399,6 +412,7 @@ def clear(self) -> None: self.fake_dtype = torch.float32 self.row_scaled_nvfp4 = False self.use_4over6 = False + self.four_over_six_e4m3_use_256 = False def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -569,6 +583,7 @@ def copy(self) -> "GroupedTensorStorage": with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, use_4over6=self.use_4over6, + four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, ) @staticmethod @@ -681,6 +696,7 @@ def make_grouped_tensor( columnwise_scale_inv_offsets = None row_scaled_nvfp4 = False use_4over6 = False + four_over_six_e4m3_use_256 = False if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -741,6 +757,7 @@ def make_grouped_tensor( elif quantizer._get_compatible_recipe().nvfp4(): row_scaled_nvfp4 = quantizer.row_scaled_nvfp4 use_4over6 = quantizer.use_4over6 + four_over_six_e4m3_use_256 = quantizer.four_over_six_e4m3_use_256 if row_scaled_nvfp4: if not rowwise_usage: raise ValueError( @@ -870,6 +887,7 @@ def make_grouped_tensor( ), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -985,6 +1003,7 @@ def split_into_quantized_tensors( nvfp4_rowwise_amax_offsets = None row_scaled_nvfp4 = self.row_scaled_nvfp4 use_4over6 = self.use_4over6 + four_over_six_e4m3_use_256 = self.four_over_six_e4m3_use_256 if recipe.nvfp4() and row_scaled_nvfp4: cum = 0 nvfp4_rowwise_amax_offsets = [0] @@ -1213,6 +1232,7 @@ def split_into_quantized_tensors( with_gemm_swizzled_scales=quantizer.optimize_for_gemm, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, + four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index ff4dcd8a79..6c3f8ca70c 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -101,6 +101,8 @@ class NVFP4TensorStorage(QuantizedTensorStorage): _row_scaled_nvfp4: bool # Whether this NVFP4 tensor uses 4over6 map-to-4/map-to-6 block selection _use_4over6: bool + # Whether this 4over6 tensor uses 256 instead of 448 as the global E4M3 scale bound + _four_over_six_e4m3_use_256: bool def __new__( cls, @@ -117,6 +119,7 @@ def __new__( fake_dtype: Optional[torch.dtype] = None, row_scaled_nvfp4: bool = False, use_4over6: bool = False, + four_over_six_e4m3_use_256: bool = False, **kwargs, ): if cls is NVFP4TensorStorage: @@ -136,6 +139,7 @@ def __new__( instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance._row_scaled_nvfp4 = row_scaled_nvfp4 instance._use_4over6 = use_4over6 + instance._four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 return instance @@ -164,6 +168,8 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") if self._use_4over6 != src._use_4over6: raise RuntimeError("NVFP4 4over6 mode mismatch in copy_from_storage") + if self._four_over_six_e4m3_use_256 != src._four_over_six_e4m3_use_256: + raise RuntimeError("NVFP4 4over6 E4M3 scale bound mismatch in copy_from_storage") def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): if dst is not None and src_tensor is not None: @@ -190,6 +196,7 @@ def get_metadata(self) -> Dict[str, Any]: "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, "row_scaled_nvfp4": self._row_scaled_nvfp4, "use_4over6": self._use_4over6, + "four_over_six_e4m3_use_256": self._four_over_six_e4m3_use_256, "fake_dtype": self._dtype, } @@ -324,6 +331,7 @@ def view(self, shape: torch.Size): with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self._row_scaled_nvfp4, use_4over6=self._use_4over6, + four_over_six_e4m3_use_256=self._four_over_six_e4m3_use_256, fake_dtype=self._dtype, ) From 38a1c4c1179bff09707cf55d3cb3903c3bf54815 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 16:39:27 -0700 Subject: [PATCH 50/57] Use e4m3 max instead of boolean, more template Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 84 ++++++------ tests/cpp/operator/test_dequantize_nvfp4.cu | 54 ++++---- tests/cpp/test_common.cu | 12 +- tests/cpp/test_common.h | 4 +- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 62 ++++----- .../nvfp4/test_nvfp4_quantize_exact.py | 52 +++---- tests/pytorch/test_recipe.py | 46 +++---- .../common/cast/dispatch/quantize.cuh | 37 +++-- .../common/cast/nvfp4/core_nvfp4.cuh | 11 +- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 33 +++-- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 128 +++++++++++------- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 113 +++++++--------- .../quantize_transpose_nvfp4_tuned_1D.cuh | 93 ++++++------- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 4 +- transformer_engine/common/common.h | 16 +-- .../transformer_engine/transformer_engine.h | 41 +++--- transformer_engine/common/recipe/__init__.py | 8 +- transformer_engine/common/recipe/nvfp4.cu | 4 +- .../common/transformer_engine.cpp | 20 +-- .../common/transpose/cast_transpose.h | 2 +- ...quantize_transpose_vector_blockwise_fp4.cu | 68 +++++----- transformer_engine/pytorch/csrc/common.h | 4 +- .../pytorch/csrc/extensions/cast.cpp | 12 +- transformer_engine/pytorch/csrc/quantizer.cpp | 29 ++-- .../pytorch/csrc/type_converters.cpp | 4 +- .../custom_recipes/quantization_ref_nvfp4.py | 72 +++++----- transformer_engine/pytorch/quantization.py | 20 +-- .../pytorch/tensor/grouped_tensor.py | 6 +- .../pytorch/tensor/nvfp4_tensor.py | 54 ++++---- .../tensor/storage/grouped_tensor_storage.py | 34 ++--- .../tensor/storage/nvfp4_tensor_storage.py | 14 +- 31 files changed, 579 insertions(+), 562 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 4a36edfca2..b13102f1f9 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -63,15 +63,13 @@ std::vector create_transpose(const InputType* const input, const size // Compute the global encode scale factor for a given global amax float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math, - const bool use_e4m3_256 = false) { - float fp8_max = 448.0f; - if (use_e4m3_256) { - fp8_max = 256.0f; - } + const int e4m3_max = 448) { + NVTE_CHECK(e4m3_max == 448 || e4m3_max == 256, "Unsupported NVFP4 E4M3 max."); + const float fp8_max = static_cast(e4m3_max); constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return the max normalized value - const float max_norm_clamp = (use_fast_math && !use_e4m3_256) + const float max_norm_clamp = (use_fast_math && e4m3_max == 448) ? Numeric_Traits::maxNorm : Numeric_Traits::maxNorm; @@ -105,7 +103,7 @@ enum class NVFP4ScalingMode { struct NVFP4FourOverSixTestConfig { bool enabled = false; - bool e4m3_use_256 = false; + int e4m3_max = 448; NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE; bool err_use_fast_math = false; }; @@ -203,13 +201,13 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float global_amax, const bool use_fast_math, const bool use_4over6 = false, - const bool use_e4m3_256 = false, + const int e4m3_max = 448, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { // Compute a global encoding/decoding scaling factor for all S_dec_b const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, - use_e4m3_256); + e4m3_max); constexpr size_t block_size_X = 16; const size_t blocks_X = divide_round_up(cols, block_size_X); @@ -306,12 +304,12 @@ void compute_2d_mathematical_scales(float (*OP)(const float), std::vector>& math_scales, const bool use_fast_math, const bool use_4over6 = false, - const bool use_e4m3_256 = false, + const int e4m3_max = 448, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, - use_e4m3_256); + e4m3_max); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -365,17 +363,17 @@ void quantize_nvfp4_2d(float (*OP)(const float), const float global_amax, const bool use_fast_math, const bool use_4over6 = false, - const bool use_e4m3_256 = false, + const int e4m3_max = 448, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math, - use_4over6, use_e4m3_256, four_over_six_candidate); + use_4over6, e4m3_max, four_over_six_candidate); const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, - use_e4m3_256); + e4m3_max); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -459,15 +457,15 @@ void quantize_nvfp4(float (*OP)(const float), const bool use_fast_math, const bool use_2d_quantization = false, const bool use_4over6 = false, - const bool use_e4m3_256 = false, + const int e4m3_max = 448, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { if (use_2d_quantization) { quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6, use_e4m3_256, four_over_six_candidate); + use_fast_math, use_4over6, e4m3_max, four_over_six_candidate); } else { quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, - use_fast_math, use_4over6, use_e4m3_256, four_over_six_candidate); + use_fast_math, use_4over6, e4m3_max, four_over_six_candidate); } } @@ -487,7 +485,7 @@ void compute_ref(float (*OP)(const float), const bool use_2d_quantization = false, const bool row_scaled_nvfp4 = false, const bool use_4over6 = false, - const bool use_e4m3_256 = false, + const int e4m3_max = 448, const NVFP4FourOverSixCandidate four_over_six_candidate = NVFP4FourOverSixCandidate::Map6) { @@ -500,7 +498,7 @@ void compute_ref(float (*OP)(const float), // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math, - use_4over6, use_e4m3_256, four_over_six_candidate); + use_4over6, e4m3_max, four_over_six_candidate); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -528,10 +526,10 @@ void compute_ref(float (*OP)(const float), // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, *amax, - use_fast_math, use_4over6, use_e4m3_256, + use_fast_math, use_4over6, e4m3_max, four_over_six_candidate); // scales already filled quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, *amax, - use_fast_math, use_4over6, use_e4m3_256, + use_fast_math, use_4over6, e4m3_max, four_over_six_candidate); // scales_t already filled return; @@ -551,7 +549,7 @@ void compute_ref(float (*OP)(const float), use_fast_math, use_2d_quantization, use_4over6, - use_e4m3_256, + e4m3_max, four_over_six_candidate); } return; @@ -559,10 +557,10 @@ void compute_ref(float (*OP)(const float), // Ref impl for basic NVFP4 quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, *amax, - use_fast_math, use_2d_quantization, use_4over6, use_e4m3_256, + use_fast_math, use_2d_quantization, use_4over6, e4m3_max, four_over_six_candidate); quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, *amax, - use_fast_math, use_2d_quantization, use_4over6, use_e4m3_256, + use_fast_math, use_2d_quantization, use_4over6, e4m3_max, four_over_six_candidate); } @@ -790,7 +788,7 @@ void performTest(float (*OP)(const float), const bool use_fast_math, const NVFP4ScalingMode scaling_mode = NVFP4ScalingMode::Block1D, const bool use_4over6 = false, - const bool use_e4m3_256 = false, + const int e4m3_max = 448, const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE, const bool use_4over6_err_use_fast_math = false) { using namespace test; @@ -833,7 +831,7 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, rowwise, columnwise, NVTE_NVFP4_1D_SCALING); output.set_nvfp4_4over6(use_4over6); - output.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); + output.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448)); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -847,7 +845,7 @@ void performTest(float (*OP)(const float), fillCase(&input, InputsFillCase::uniform); if (use_4over6 && is_row_scaled_nvfp4) { - const float target_row_amax = (use_e4m3_256 ? 256.0f : 448.0f) * 6.0f * 8.0f; + const float target_row_amax = static_cast(e4m3_max) * 6.0f * 8.0f; auto *input_vals = input.rowwise_cpu_dptr(); for (size_t row = 0; row < rows; ++row) { float row_amax = 0.0f; @@ -899,7 +897,7 @@ void performTest(float (*OP)(const float), } else { // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues if (use_4over6) { - ref_amax.assign(1, (use_e4m3_256 ? 256.0f : 448.0f) * 6.0f * 8.0f); + ref_amax.assign(1, static_cast(e4m3_max) * 6.0f * 8.0f); } else { ref_amax.assign(1, 448.0f * 6.0f * 8.0f); } @@ -935,7 +933,7 @@ void performTest(float (*OP)(const float), is_2d_quantization, is_row_scaled_nvfp4, use_4over6, - use_e4m3_256, + e4m3_max, NVFP4FourOverSixCandidate::Map4); compute_ref(OP, input.rowwise_cpu_dptr(), @@ -952,7 +950,7 @@ void performTest(float (*OP)(const float), is_2d_quantization, is_row_scaled_nvfp4, use_4over6, - use_e4m3_256, + e4m3_max, NVFP4FourOverSixCandidate::Map6); } else { compute_ref(OP, @@ -985,7 +983,7 @@ void performTest(float (*OP)(const float), quant_config.set_rng_state(rng_state.data()); quant_config.set_nvfp4_2d_quantization(is_2d_quantization); quant_config.set_nvfp4_4over6(use_4over6); - quant_config.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); + quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448)); quant_config.set_nvfp4_4over6_err_mode(err_mode); quant_config.set_nvfp4_4over6_err_use_fast_math(use_4over6 && use_4over6_err_use_fast_math); @@ -1132,7 +1130,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, performTest(OP, tensor_dims, use_fast_math, scaling_mode, config.enabled, - config.e4m3_use_256, config.err_mode, + config.e4m3_max, config.err_mode, config.err_use_fast_math); ); } @@ -1172,10 +1170,10 @@ std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) const NVFP4FourOverSixTestConfig& config = std::get<5>(param); if (config.enabled) { name += "X4OVER6"; - if (config.e4m3_use_256) { - name += "XE4M3_USE_256"; + if (config.e4m3_max == 448) { + name += "XE4M3_MAX_448"; } else { - name += "XE4M3_USE_448"; + name += "XE4M3_MAX_256"; } if (config.err_mode == kNVTENVFP44Over6ErrMSE) { name += "XMSE"; @@ -1229,14 +1227,14 @@ INSTANTIATE_TEST_SUITE_P( NVFP4ScalingMode::RowScaled1D, NVFP4ScalingMode::Block2D), // scaling_mode ::testing::Values( - NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMAE, false}, - NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMAE, true}, - NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMSE, false}, - NVFP4FourOverSixTestConfig{true, false, kNVTENVFP44Over6ErrMSE, true}, - NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMAE, false}, - NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMAE, true}, - NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMSE, false}, - NVFP4FourOverSixTestConfig{true, true, kNVTENVFP44Over6ErrMSE, true})), // four_over_six_config + NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMAE, false}, + NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMAE, true}, + NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMSE, false}, + NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMSE, true}, + NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMAE, false}, + NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMAE, true}, + NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMSE, false}, + NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMSE, true})), // four_over_six_config [](const testing::TestParamInfo& info) { return test_name(info.param); }); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 823fb2f9c8..0dfaca7e58 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -47,8 +47,8 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, size_t rows, size_t cols, size_t scale_stride, - bool use_e4m3_256) { - const float factor_inv = 1.0f / (6.0f * (use_e4m3_256 ? 256.0f : 448.0f)); + int e4m3_max) { + const float factor_inv = 1.0f / (6.0f * static_cast(e4m3_max)); constexpr size_t BLOCK_SIZE = 16; const size_t Mread = cols / BLOCK_SIZE; const size_t bytes_per_block = BLOCK_SIZE / 2; @@ -93,7 +93,7 @@ template void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, const bool row_scaled_nvfp4, const bool use_4over6, - const bool use_e4m3_256) { + const int e4m3_max) { using namespace test; DType otype = TypeInfo::dtype; @@ -109,9 +109,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Configure quantized tensor amax size_t amax_size = 1; quantized.set_nvfp4_4over6(use_4over6); - quantized.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); + quantized.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448)); ASSERT_EQ(quantized.nvfp4_4over6(), use_4over6); - ASSERT_EQ(quantized.nvfp4_4over6_e4m3_use_256(), use_4over6 && use_e4m3_256); + ASSERT_EQ(quantized.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448)); if (row_scaled_nvfp4) { quantized.set_row_scaled_nvfp4(true); amax_size = rows; @@ -125,7 +125,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, if (rows > 0 && cols > 0) { QuantizationConfigWrapper quant_config; quant_config.set_nvfp4_4over6(use_4over6); - quant_config.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); + quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448)); nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -156,7 +156,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, std::make_unique(rows * cols); compute_ref_dequantize_nvfp4( fp4_data, scales, amax_vals, ref_output.get(), - rows, cols, scale_stride, use_4over6 && use_e4m3_256); + rows, cols, scale_stride, (use_4over6 ? e4m3_max : 448)); // Compare results from TE and reference impls auto [atol, rtol] = getTolerances(otype); @@ -168,7 +168,7 @@ template void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, const bool row_scaled_nvfp4, const bool use_4over6, - const bool use_e4m3_256) { + const int e4m3_max) { using namespace test; DType otype = TypeInfo::dtype; @@ -178,9 +178,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); quantized_compact.set_nvfp4_4over6(use_4over6); - quantized_compact.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); + quantized_compact.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448)); ASSERT_EQ(quantized_compact.nvfp4_4over6(), use_4over6); - ASSERT_EQ(quantized_compact.nvfp4_4over6_e4m3_use_256(), use_4over6 && use_e4m3_256); + ASSERT_EQ(quantized_compact.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448)); if (row_scaled_nvfp4) { quantized_compact.set_row_scaled_nvfp4(true); } else if (rows > 0 && cols > 0) { @@ -192,7 +192,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, if (rows > 0 && cols > 0) { QuantizationConfigWrapper quant_config; quant_config.set_nvfp4_4over6(use_4over6); - quant_config.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); + quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448)); nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); cudaDeviceSynchronize(); } @@ -206,9 +206,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); quantized_swizzled.set_nvfp4_4over6(use_4over6); - quantized_swizzled.set_nvfp4_4over6_e4m3_use_256(use_4over6 && use_e4m3_256); + quantized_swizzled.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448)); ASSERT_EQ(quantized_swizzled.nvfp4_4over6(), use_4over6); - ASSERT_EQ(quantized_swizzled.nvfp4_4over6_e4m3_use_256(), use_4over6 && use_e4m3_256); + ASSERT_EQ(quantized_swizzled.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448)); if (row_scaled_nvfp4) { quantized_swizzled.set_row_scaled_nvfp4(true); } else { @@ -285,7 +285,7 @@ class DequantizeNVFP4TestSuite : public ::testing::TestWithParam transformer_engine::DType, bool, bool, - bool>> {}; + int>> {}; TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { @@ -297,11 +297,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); const bool use_4over6 = std::get<3>(GetParam()); - const bool use_e4m3_256 = use_4over6 && std::get<4>(GetParam()); + const int e4m3_max = use_4over6 ? std::get<4>(GetParam()) : 448; TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, use_e4m3_256); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, e4m3_max); ); } @@ -313,7 +313,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Bool(), ::testing::Bool(), - ::testing::Bool()), + ::testing::Values(448, 256)), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + @@ -322,9 +322,9 @@ INSTANTIATE_TEST_SUITE_P( (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + (std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" + (std::get<3>(info.param) - ? (std::get<4>(info.param) ? "E4M3Use256" : "E4M3Use448") - : (std::get<4>(info.param) ? "E4M3Use256Ignored" - : "E4M3Use448")); + ? (std::get<4>(info.param) == 256 ? "E4M3Max256" : "E4M3Max448") + : (std::get<4>(info.param) == 256 ? "E4M3Max256Ignored" + : "E4M3Max448")); return name; } ); @@ -334,7 +334,7 @@ class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam transformer_engine::DType, bool, bool, - bool>> {}; + int>> {}; TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) { @@ -346,11 +346,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); const bool use_4over6 = std::get<3>(GetParam()); - const bool use_e4m3_256 = use_4over6 && std::get<4>(GetParam()); + const int e4m3_max = use_4over6 ? std::get<4>(GetParam()) : 448; TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, use_e4m3_256); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, e4m3_max); ); } @@ -362,7 +362,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Bool(), ::testing::Bool(), - ::testing::Bool()), + ::testing::Values(448, 256)), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + @@ -371,9 +371,9 @@ INSTANTIATE_TEST_SUITE_P( (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + (std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" + (std::get<3>(info.param) - ? (std::get<4>(info.param) ? "E4M3Use256" : "E4M3Use448") - : (std::get<4>(info.param) ? "E4M3Use256Ignored" - : "E4M3Use448")) + "X" + + ? (std::get<4>(info.param) == 256 ? "E4M3Max256" : "E4M3Max448") + : (std::get<4>(info.param) == 256 ? "E4M3Max256Ignored" + : "E4M3Max448")) + "X" + "Swizzled"; return name; } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 055c3b744e..3569ab258a 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -446,10 +446,10 @@ void Tensor::set_nvfp4_4over6(bool nvfp4_4over6) { tensor_.set_nvfp4_4over6(nvfp4_4over6); } -void Tensor::set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256) { +void Tensor::set_nvfp4_e4m3_max(int nvfp4_e4m3_max) { NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, - "NVFP4 4over6 E4M3 256 scale bound is only supported for NVFP4 tensors."); - tensor_.set_nvfp4_4over6_e4m3_use_256(nvfp4_4over6_e4m3_use_256); + "NVFP4 E4M3 max is only supported for NVFP4 tensors."); + tensor_.set_nvfp4_e4m3_max(nvfp4_e4m3_max); } bool Tensor::nvfp4_4over6() const { @@ -458,10 +458,10 @@ bool Tensor::nvfp4_4over6() const { return tensor_.get_nvfp4_4over6(); } -bool Tensor::nvfp4_4over6_e4m3_use_256() const { +int Tensor::nvfp4_e4m3_max() const { NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, - "NVFP4 4over6 E4M3 256 scale bound is only supported for NVFP4 tensors."); - return tensor_.get_nvfp4_4over6_e4m3_use_256(); + "NVFP4 E4M3 max is only supported for NVFP4 tensors."); + return tensor_.get_nvfp4_e4m3_max(); } void Tensor::to_cpu() { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 8c6c900040..851593cae7 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -294,14 +294,14 @@ class Tensor { } bool nvfp4_4over6() const; - bool nvfp4_4over6_e4m3_use_256() const; + int nvfp4_e4m3_max() const; void set_tensor_amax_nullptr(); void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales); void set_row_scaled_nvfp4(bool row_scaled_nvfp4); void set_nvfp4_4over6(bool nvfp4_4over6); - void set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256); + void set_nvfp4_e4m3_max(int nvfp4_e4m3_max); void to_cpu(); void from_cpu(); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 5f9a839a78..16a9387dc6 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -29,11 +29,11 @@ def check_nvfp4_gemm_versus_reference( w_columnwise: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, - four_over_six_err_mode: str = "MAE", + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", ): - if four_over_six_e4m3_use_256 and not use_4over6: - pytest.skip("E4M3 256 bound is only meaningful for 4over6") + if nvfp4_e4m3_max != 448 and not use_4over6: + pytest.skip("E4M3 max 256 is only meaningful for 4over6") te_dtype = tex.DType.kFloat4E2M1 # Setup device and random seed @@ -65,8 +65,8 @@ def check_nvfp4_gemm_versus_reference( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -77,8 +77,8 @@ def check_nvfp4_gemm_versus_reference( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) # Quantize x and w @@ -135,8 +135,8 @@ def check_nvfp4_gemm_versus_reference( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -146,8 +146,8 @@ def check_nvfp4_gemm_versus_reference( eps=0.0, quant_tile_shape=(1, 16), use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) # Create reference quantized tensors needed by reference GEMM @@ -250,7 +250,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, use_4over6: bool = False, - four_over_six_err_mode: str = "MAE", + nvfp4_4over6_err_mode: str = "MAE", ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -269,7 +269,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( with_post_rht_amax=False, row_scaled_nvfp4=True, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -280,7 +280,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4 = [] @@ -345,7 +345,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( K: int, N: int, use_4over6: bool = False, - four_over_six_err_mode: str = "MAE", + nvfp4_4over6_err_mode: str = "MAE", ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -365,7 +365,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_post_rht_amax=False, row_scaled_nvfp4=True, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_tensorwise_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -376,7 +376,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -387,7 +387,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_rht=False, with_post_rht_amax=False, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_row_scaled = x_row_scaled_quantizer.update_quantized( @@ -449,8 +449,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("four_over_six_e4m3_use_256", [False, True], ids=["e4m3_448", "e4m3_256"]) -@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -463,8 +463,8 @@ def test_nvfp4_gemm_versus_reference( is_w_columnwise: bool, row_scaled_nvfp4: bool, use_4over6: bool, - four_over_six_e4m3_use_256: bool, - four_over_six_err_mode: str, + nvfp4_e4m3_max: int, + nvfp4_4over6_err_mode: str, ): if row_scaled_nvfp4: if accumulate: @@ -484,8 +484,8 @@ def test_nvfp4_gemm_versus_reference( w_columnwise=is_w_columnwise, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -512,7 +512,7 @@ def test_nvfp4_gemm_versus_reference( @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( m_splits: list[int], k: int, @@ -523,7 +523,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, use_4over6: bool, - four_over_six_err_mode: str, + nvfp4_4over6_err_mode: str, ): check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( x_dtype=x_dtype, @@ -535,7 +535,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( use_bias=use_bias, single_output=single_output, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -560,7 +560,7 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( @pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_row_scaled_gemm_matches_emulated( M: int, K: int, @@ -569,7 +569,7 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( w_dtype: torch.dtype, out_dtype: torch.dtype, use_4over6: bool, - four_over_six_err_mode: str, + nvfp4_4over6_err_mode: str, ): check_nvfp4_row_scaled_gemm_matches_emulated( x_dtype=x_dtype, @@ -579,5 +579,5 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( K=K, N=N, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 12ffdc9329..043d71919e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -53,11 +53,11 @@ def check_quantization_nvfp4_versus_reference( with_2d_quantization: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, - four_over_six_err_mode: str = "MAE", + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", ) -> None: - if four_over_six_e4m3_use_256 and not use_4over6: - pytest.skip("E4M3 256 bound is only meaningful for 4over6") + if nvfp4_e4m3_max != 448 and not use_4over6: + pytest.skip("E4M3 max 256 is only meaningful for 4over6") maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6, x_dtype, M, N ) @@ -84,8 +84,8 @@ def check_quantization_nvfp4_versus_reference( with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -120,8 +120,8 @@ def check_quantization_nvfp4_versus_reference( quant_tile_shape=quant_tile_shape, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -198,8 +198,8 @@ def check_quantization_nvfp4_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("four_over_six_e4m3_use_256", [False, True], ids=["e4m3_448", "e4m3_256"]) -@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -210,8 +210,8 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization: bool, row_scaled_nvfp4: bool, use_4over6: bool, - four_over_six_e4m3_use_256: bool, - four_over_six_err_mode: str, + nvfp4_e4m3_max: int, + nvfp4_4over6_err_mode: str, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -223,8 +223,8 @@ def test_quantization_block_tiling_versus_reference( with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -243,7 +243,7 @@ def test_quantization_block_tiling_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -253,7 +253,7 @@ def test_nvfp4_quantization_extrema_versus_reference( use_cpp_allocator: bool, row_scaled_nvfp4: bool, use_4over6: bool, - four_over_six_err_mode: str, + nvfp4_4over6_err_mode: str, ): maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 @@ -281,7 +281,7 @@ def test_nvfp4_quantization_extrema_versus_reference( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: @@ -314,7 +314,7 @@ def test_nvfp4_quantization_extrema_versus_reference( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -360,7 +360,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -369,7 +369,7 @@ def test_nvfp4_quantization_boundary_values( use_cpp_allocator: bool, row_scaled_nvfp4: bool, use_4over6: bool, - four_over_six_err_mode: str, + nvfp4_4over6_err_mode: str, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -411,7 +411,7 @@ def test_nvfp4_quantization_boundary_values( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: @@ -444,7 +444,7 @@ def test_nvfp4_quantization_boundary_values( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -490,7 +490,7 @@ def test_nvfp4_quantization_boundary_values( ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) -@pytest.mark.parametrize("four_over_six_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, @@ -499,7 +499,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( use_cpp_allocator: bool, row_scaled_nvfp4: bool, use_4over6: bool, - four_over_six_err_mode: str, + nvfp4_4over6_err_mode: str, ): maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 @@ -527,7 +527,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: @@ -560,7 +560,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_err_mode=four_over_six_err_mode, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 9425c7019a..0c8de0526f 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -521,20 +521,18 @@ def test_quantizer_update(self, module_class): ids=["default", "weights", "activations", "all"], ) @pytest.mark.parametrize( - "nvfp4_4over6_e4m3_use_256", + "nvfp4_e4m3_max", [None, "weights", "activations", "all"], ids=["e4m3_448", "e4m3_256_weights", "e4m3_256_activations", "e4m3_256_all"], ) @pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) -def test_nvfp4_row_scaled_quantizer_roles( - nvfp4_4over6, nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode -): +def test_nvfp4_row_scaled_quantizer_roles(nvfp4_4over6, nvfp4_e4m3_max, nvfp4_4over6_err_mode): recipe = NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, nvfp4_4over6=nvfp4_4over6, - nvfp4_4over6_e4m3_use_256=nvfp4_4over6_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, row_scaled_activation=True, ) @@ -548,16 +546,18 @@ def expected_use_4over6(tensor_type): return tensor_type != "weight" return False - def expected_e4m3_use_256(tensor_type): + def expected_e4m3_max(tensor_type): if not expected_use_4over6(tensor_type): - return False - if nvfp4_4over6_e4m3_use_256 == "all": - return True - if nvfp4_4over6_e4m3_use_256 == "weights": - return tensor_type == "weight" - if nvfp4_4over6_e4m3_use_256 == "activations": - return tensor_type != "weight" - return False + return 448 + if nvfp4_e4m3_max == "all": + return 256 + if nvfp4_e4m3_max == "weights": + if tensor_type == "weight": + return 256 + if nvfp4_e4m3_max == "activations": + if tensor_type != "weight": + return 256 + return 448 forward_quantizers = NVFP4BlockScalingRecipeState( recipe, @@ -568,10 +568,10 @@ def expected_e4m3_use_256(tensor_type): assert [q.use_4over6 for q in forward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("input", "weight", "output") ] - assert [q.four_over_six_e4m3_use_256 for q in forward_quantizers] == [ - expected_e4m3_use_256(tensor_type) for tensor_type in ("input", "weight", "output") + assert [q.nvfp4_e4m3_max for q in forward_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("input", "weight", "output") ] - assert [q.four_over_six_err_mode for q in forward_quantizers] == [nvfp4_4over6_err_mode] * 3 + assert [q.nvfp4_4over6_err_mode for q in forward_quantizers] == [nvfp4_4over6_err_mode] * 3 assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) @@ -590,10 +590,10 @@ def expected_e4m3_use_256(tensor_type): assert [q.use_4over6 for q in role_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("weight", "input", "output", "input") ] - assert [q.four_over_six_e4m3_use_256 for q in role_quantizers] == [ - expected_e4m3_use_256(tensor_type) for tensor_type in ("weight", "input", "output", "input") + assert [q.nvfp4_e4m3_max for q in role_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("weight", "input", "output", "input") ] - assert [q.four_over_six_err_mode for q in role_quantizers] == [nvfp4_4over6_err_mode] * 4 + assert [q.nvfp4_4over6_err_mode for q in role_quantizers] == [nvfp4_4over6_err_mode] * 4 backward_quantizers = NVFP4BlockScalingRecipeState( recipe, @@ -608,10 +608,10 @@ def expected_e4m3_use_256(tensor_type): assert [q.use_4over6 for q in backward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("grad_output", "grad_input") ] - assert [q.four_over_six_e4m3_use_256 for q in backward_quantizers] == [ - expected_e4m3_use_256(tensor_type) for tensor_type in ("grad_output", "grad_input") + assert [q.nvfp4_e4m3_max for q in backward_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("grad_output", "grad_input") ] - assert [q.four_over_six_err_mode for q in backward_quantizers] == [nvfp4_4over6_err_mode] * 2 + assert [q.nvfp4_4over6_err_mode for q in backward_quantizers] == [nvfp4_4over6_err_mode] * 2 @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index e8acf5f276..ea016260d7 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -103,13 +103,12 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, "Tensor and quantization config have inconsistent options for NVFP4 4over6."); - NVTE_CHECK( - quant_config_cpp.nvfp4_4over6_e4m3_use_256 == output_tensor->nvfp4_4over6_e4m3_use_256, - "Tensor and quantization config have inconsistent options for NVFP4 4over6 " - "E4M3 scale bound."); + NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); const bool use_4over6 = quant_config_cpp.nvfp4_4over6; - NVTE_CHECK(use_4over6 || !quant_config_cpp.nvfp4_4over6_e4m3_use_256, - "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); + NVTE_CHECK(use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); if (row_scaled_nvfp4) { @@ -147,7 +146,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/row_scaled_nvfp4, /*use_4over6=*/use_4over6, - /*use_4over6_e4m3_use_256=*/quant_config_cpp.nvfp4_4over6_e4m3_use_256, + /*nvfp4_e4m3_max=*/quant_config_cpp.nvfp4_e4m3_max, /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, /*nvfp4_4over6_err_use_fast_math=*/quant_config_cpp.nvfp4_4over6_err_use_fast_math, /*noop_tensor=*/noop_tensor->data, @@ -266,13 +265,12 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens auto dtype = grad_tensor->dtype(); NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, "Tensor and quantization config have inconsistent options for NVFP4 4over6."); - NVTE_CHECK( - quant_config_cpp.nvfp4_4over6_e4m3_use_256 == output_tensor->nvfp4_4over6_e4m3_use_256, - "Tensor and quantization config have inconsistent options for NVFP4 4over6 " - "E4M3 scale bound."); + NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); const bool use_4over6 = quant_config_cpp.nvfp4_4over6; - NVTE_CHECK(use_4over6 || !quant_config_cpp.nvfp4_4over6_e4m3_use_256, - "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); + NVTE_CHECK(use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, @@ -305,7 +303,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/false, /*use_4over6=*/use_4over6, - /*use_4over6_e4m3_use_256=*/quant_config_cpp.nvfp4_4over6_e4m3_use_256, + /*nvfp4_e4m3_max=*/quant_config_cpp.nvfp4_e4m3_max, /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, /*nvfp4_4over6_err_use_fast_math=*/quant_config_cpp.nvfp4_4over6_err_use_fast_math, /*noop_tensor=*/noop_tensor->data, @@ -406,14 +404,13 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou for (const auto *output_tensor : output_tensors) { NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, "Tensor and quantization config have inconsistent options for NVFP4 4over6."); - NVTE_CHECK( - quant_config_cpp.nvfp4_4over6_e4m3_use_256 == output_tensor->nvfp4_4over6_e4m3_use_256, - "Tensor and quantization config have inconsistent options for NVFP4 4over6 " - "E4M3 scale bound."); + NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); } const bool use_4over6 = quant_config_cpp.nvfp4_4over6; - NVTE_CHECK(use_4over6 || !quant_config_cpp.nvfp4_4over6_e4m3_use_256, - "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); + NVTE_CHECK(use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "2D quantization is not supported for group quantize."); NVTE_CHECK(!use_4over6, "NVFP4 4over6 quantization is not supported for group quantize."); diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index ace00fb415..3820430d5b 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -76,13 +76,14 @@ namespace core { using namespace ptx; // Compute the global encode scale factor for a given global amax. -// NVFP4 uses 448 by default. Some 4over6 tensors use 256 to leave room for -// map-to-4 scale expansion. -template +// NVFP4 uses the full E4M3 range by default. Some 4over6 tensors dispatch +// E4M3_MAX=256 to leave room for map-to-4 scale expansion. +template __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { using namespace detail; - constexpr float fp8_max = USE_E4M3_256 ? 256.0f : TypeExtrema::max; // 448.0f; - constexpr float fp4_max = TypeExtrema::max; // 6.0f; + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float fp8_max = static_cast(E4M3_MAX); + constexpr float fp4_max = TypeExtrema::max; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 0228d2786a..faf3c58adf 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -31,7 +31,7 @@ namespace dispatch { namespace nvfp4 { namespace dequantize_kernel { #if FP4_TYPE_SUPPORTED -template +template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, const float *const tensor_amax, const size_t N, const size_t M, @@ -64,7 +64,8 @@ __global__ void __launch_bounds__(512) value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; float amax = ROW_SCALED_NVFP4 ? tensor_amax[y] : tensor_amax[0]; - constexpr float factor_inv = 1.0f / (6.0f * (USE_E4M3_256 ? 256.0f : 448.0f)); + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float factor_inv = 1.0f / (6.0f * static_cast(E4M3_MAX)); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { @@ -91,7 +92,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; - const bool use_e4m3_256 = input.nvfp4_4over6_e4m3_use_256; + const int e4m3_max = input.nvfp4_e4m3_max; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -114,15 +115,23 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, TRANSFORMER_ENGINE_SWITCH_CONDITION( row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_e4m3_256, USE_E4M3_256, - dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, - input.scale_inv.shape.back(), num_scale_tiles_X);););); // NOLINT(*) - ); // NOLINT(*) + if (e4m3_max == 256) { + dequantize_fp4_kernel + <<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X); + } else { + NVTE_CHECK(e4m3_max == 448, "Unsupported NVFP4 E4M3 max (got ", e4m3_max, ")"); + dequantize_fp4_kernel + <<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X); + });); // NOLINT(*) + ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 9d4b0410c8..0f75d00498 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -46,6 +46,40 @@ namespace core { } \ } +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH(E4M3_MAX_VALUE, E4M3_MAX_CONST, ...) \ + if ((E4M3_MAX_VALUE) == 256) { \ + constexpr int E4M3_MAX_CONST = 256; \ + { __VA_ARGS__ } \ + } else { \ + NVTE_CHECK((E4M3_MAX_VALUE) == 448, "Unsupported NVFP4 E4M3 max."); \ + constexpr int E4M3_MAX_CONST = 448; \ + { __VA_ARGS__ } \ + } + +template +struct NVFP44Over6Config { + static constexpr bool enabled = kEnabled; + static constexpr NVTENVFP44Over6ErrMode err_mode = kErrMode; + static constexpr bool err_use_fast_math = kErrUseFastMath; +}; + +using NVFP44Over6DisabledConfig = NVFP44Over6Config; + +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_CONFIG_SWITCH(USE_4OVER6_VALUE, ERR_MODE_VALUE, \ + ERR_USE_FAST_MATH_VALUE, CONFIG_CONST, ...) \ + if (USE_4OVER6_VALUE) { \ + TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( \ + ERR_MODE_VALUE, ERR_MODE_CONST, \ + TRANSFORMER_ENGINE_SWITCH_CONDITION(ERR_USE_FAST_MATH_VALUE, ERR_USE_FAST_MATH_CONST, { \ + using CONFIG_CONST = NVFP44Over6Config; \ + { __VA_ARGS__ } \ + });); \ + } else { \ + using CONFIG_CONST = NVFP44Over6DisabledConfig; \ + { __VA_ARGS__ } \ + } + __device__ __forceinline__ void compute_4over6_decoding_scaling_factors( const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, nvfp4_scale_t &S_dec_b_fp8_map6) { @@ -126,10 +160,12 @@ __device__ __forceinline__ float compute_4over6_error(const float diff) { } } -template +template __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t S_dec_b_fp8, const float global_amax, float *err) { + static_assert(FourOverSixConfig::enabled, + "4over6 conversion helpers require an enabled 4over6 config."); uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -172,11 +208,11 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( const uint16_t out_dequant_4_lo = out_dequant_4 & 0xFFFF; constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f - constexpr float fp8_4over6_max = - USE_E4M3_256 ? 256.0f : detail::TypeExtrema::max; // 448.0f + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float fp8_4over6_max = static_cast(E4M3_MAX); constexpr float err_denom = fp4_max * fp8_4over6_max; const float sf = static_cast(S_dec_b_fp8); - if constexpr (USE_ERR_USE_FAST_MATH) { + if constexpr (FourOverSixConfig::err_use_fast_math) { const float dequant[8] = { __half2float(__ushort_as_half(out_dequant_1_lo)), __half2float(__ushort_as_half(out_dequant_1_hi)), @@ -191,7 +227,7 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( for (int i = 0; i < 8; ++i) { const float val = dequant[i] * sf * global_amax / err_denom; const float diff = val - x[i]; - *err += compute_4over6_error(diff); + *err += compute_4over6_error(diff); } } else { const float val0 = __fdiv_rn( @@ -228,14 +264,14 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( const float diff6 = __fsub_rn(val6, x[6]); const float diff7 = __fsub_rn(val7, x[7]); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff0)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff1)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff2)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff3)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff4)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff5)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff6)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff7)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff0)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff1)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff2)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff3)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff4)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff5)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff6)); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff7)); } } else { NVTE_DEVICE_ERROR( @@ -246,8 +282,7 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( return out; } -template +template __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], const float (&second_half)[8], const QuantizationScales4Over6 &scaling_factors, @@ -255,29 +290,29 @@ __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8] float &err_map6, uint32_t (&rOut_map4)[2], uint32_t (&rOut_map6)[2]) { if constexpr (REVERSE_PACK_ORDER) { - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } else { - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( first_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map4), scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( + rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( second_half, static_cast(scaling_factors.SFcoefficient_map6), scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); } @@ -295,8 +330,7 @@ selected_4over6_scale(const bool pick_map4, const QuantizationScales4Over6 &scal return scaling_factors.S_dec_b_fp8_map6; } -template +template __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], const float (&second_half)[8], const QuantizationScales4Over6 &scaling_factors, @@ -308,7 +342,7 @@ __device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8] __align__(8) uint32_t rOut_map4[2]; __align__(8) uint32_t rOut_map6[2]; - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, err_map4, err_map6, rOut_map4, rOut_map6); @@ -408,7 +442,7 @@ __device__ __forceinline__ void load_4over6_vec_index_halves_16x(const vec_type } } -template +template __device__ __forceinline__ void quantize_4over6_candidates_16x( const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { @@ -417,7 +451,7 @@ __device__ __forceinline__ void quantize_4over6_candidates_16x( load_4over6_contiguous_halves_16x(x, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } @@ -455,8 +489,8 @@ __device__ __forceinline__ bool record_and_select_4over6_2d_block( return scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; } -template +template __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( const float (&x)[16], const float block_amax, const float global_encode_scale, const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, @@ -465,8 +499,8 @@ __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( nvfp4_scale_t &S_dec_b_fp8, QuantizationCandidates4Over6 &candidates) { const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( block_amax, global_encode_scale, global_decode_scale); - quantize_4over6_candidates_16x( - x, scaling_factors, global_amax, candidates); + quantize_4over6_candidates_16x(x, scaling_factors, global_amax, + candidates); const bool pick_map4 = record_and_select_4over6_2d_block( @@ -529,8 +563,7 @@ __device__ __forceinline__ void store_selected_4over6_packed_16x( store_4over6_packed_16x(candidates.selected_packed(pick_map4), output_vec); } -template +template __device__ __forceinline__ void quantize_4over6_contiguous_16x( const input_type *x, const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -538,12 +571,11 @@ __device__ __forceinline__ void quantize_4over6_contiguous_16x( float second_half[8]; load_4over6_contiguous_halves_16x(x, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_pair_array_16x( const pair_type (&x)[2][4], const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -551,12 +583,11 @@ __device__ __forceinline__ void quantize_4over6_pair_array_16x( float second_half[8]; load_4over6_pair_array_halves_16x(x, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { @@ -565,13 +596,12 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( load_4over6_vec2_array_halves_16x(x, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } -template +template __device__ __forceinline__ void quantize_4over6_vec2_array_16x( const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -579,12 +609,11 @@ __device__ __forceinline__ void quantize_4over6_vec2_array_16x( float second_half[8]; load_4over6_vec2_array_halves_16x(x, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } -template +template __device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, const float global_amax, QuantizationCandidates4Over6 &candidates) { @@ -593,13 +622,12 @@ __device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); candidates.reset_errors(); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, candidates.err_map4, candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); } -template +template __device__ __forceinline__ void quantize_4over6_vec_index_16x( const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { @@ -607,7 +635,7 @@ __device__ __forceinline__ void quantize_4over6_vec_index_16x( float second_half[8]; load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); - quantize_4over6_16x( + quantize_4over6_16x( first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); } diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 09372f8b98..c7448d1849 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -784,9 +784,8 @@ __global__ void __launch_bounds__(THREADS_NUM) } template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE, int E4M3_MAX, + typename FourOverSixConfig> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -900,7 +899,7 @@ __global__ void __launch_bounds__(THREADS_NUM) using FourOverSixScratch = QuantizationScratch4Over6; FourOverSixScratch *four_over_six_scratch = nullptr; - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { constexpr size_t four_over_six_scratch_offset = in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales + out_mem_colwise_scales; @@ -916,14 +915,14 @@ __global__ void __launch_bounds__(THREADS_NUM) const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f - : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); // NOTE: This is to match with how emulation code was written. const float S_dec_rowwise = 1.0 / S_enc_rowwise; const float S_enc_colwise = (amax_colwise_ptr == nullptr) ? S_enc_rowwise - : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); const float S_dec_colwise = 1.0 / S_enc_colwise; const float global_amax_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f : *amax_rowwise_ptr; const float global_amax_colwise = @@ -1099,7 +1098,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { float x_4over6[SCALE_DIM]; #pragma unroll for (int i = 0; i < SCALE_DIM; ++i) { @@ -1113,12 +1112,12 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t block_col = threadIdx.x % BLOCK_DIM; QuantizationCandidates4Over6 candidates; nvfp4_scale_t S_dec_b_fp8; - const bool pick_map4 = quantize_and_select_4over6_2d_block_16x< - USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_USE_FAST_MATH, USE_4OVER6_E4M3_USE_256, BLOCK_DIM, - BLOCKS_PER_TILE_Y, BLOCKS_PER_TILE_X>( - x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, - block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, - candidates); + const bool pick_map4 = + quantize_and_select_4over6_2d_block_16x( + x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, + block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, + candidates); const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; @@ -1213,7 +1212,7 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { #pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { in_4over6_rowwise[swizzled_group_idx + 2 * e] = @@ -1234,7 +1233,7 @@ __global__ void __launch_bounds__(THREADS_NUM) // Load cached elements in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { in_4over6_rowwise[swizzled_group_idx + e] = @@ -1266,22 +1265,22 @@ __global__ void __launch_bounds__(THREADS_NUM) elt = static_cast(static_cast(elt)); } in_compute_rowwise[j] = elt; - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { in_4over6_rowwise[swizzled_group_idx + e] = elt; } } } } - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { QuantizationCandidates4Over6 candidates; nvfp4_scale_t S_dec_b_fp8; - const bool pick_map4 = quantize_and_select_4over6_2d_block_16x< - USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_USE_FAST_MATH, USE_4OVER6_E4M3_USE_256, BLOCK_DIM, - BLOCKS_PER_TILE_Y, BLOCKS_PER_TILE_X>( - in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, - block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, S_dec_b_fp8, - candidates); + const bool pick_map4 = + quantize_and_select_4over6_2d_block_16x( + in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, + block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, + S_dec_b_fp8, candidates); const size_t scales_offset_Y = scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; @@ -1421,8 +1420,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; - const bool use_4over6_e4m3_use_256 = - use_4over6 && quant_config && quant_config->nvfp4_4over6_e4m3_use_256; + const int nvfp4_e4m3_max = use_4over6 && quant_config ? quant_config->nvfp4_e4m3_max : 448; const NVTENVFP44Over6ErrMode use_4over6_err_mode = use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; const bool use_4over6_err_use_fast_math = @@ -1544,43 +1542,34 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION( row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_SWITCH_CONDITION(use_4over6, USE_4OVER6, { - TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( - use_4over6_err_mode, USE_4OVER6_ERR_MODE, { - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_err_use_fast_math, USE_4OVER6_ERR_USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_e4m3_use_256, USE_4OVER6_E4M3_USE_256, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = - quantize_transpose_nvfp4_kernel; - - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel< - COMPUTE_ACTIVATIONS, ParamOP, OP, IType, USE_STOCHASTIC_ROUNDING, - RETURN_TRANSPOSE, USE_4OVER6, USE_4OVER6_E4M3_USE_256, - USE_4OVER6_ERR_MODE, USE_4OVER6_ERR_USE_FAST_MATH>; - } - using FourOverSixScratch = - core::QuantizationScratch4Over6; - constexpr size_t dshmem_size = - base_dshmem_size + - FourOverSixScratch::template dynamic_shared_memory_size< - use_2d_quantization, USE_4OVER6>(); - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, - scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, - amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, - rng_state); - }););); - }); - }););); + TRANSFORMER_ENGINE_NVFP4_4OVER6_CONFIG_SWITCH( + use_4over6, use_4over6_err_mode, use_4over6_err_use_fast_math, FourOverSixConfig, + TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( + nvfp4_e4m3_max, E4M3_MAX, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel< + COMPUTE_ACTIVATIONS, ParamOP, OP, IType, USE_STOCHASTIC_ROUNDING, + RETURN_TRANSPOSE, E4M3_MAX, FourOverSixConfig>; + } + using FourOverSixScratch = core::QuantizationScratch4Over6< + NVFP4_2D_BLOCK_DIM, NVFP4_2D_BLOCKS_PER_TILE_Y, NVFP4_2D_BLOCKS_PER_TILE_X>; + constexpr size_t dshmem_size = + base_dshmem_size + FourOverSixScratch::template dynamic_shared_memory_size< + use_2d_quantization, FourOverSixConfig::enabled>(); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, + scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, + amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, + rng_state); + }););););); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index 68bdc84da2..a60097b9df 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -185,9 +185,8 @@ compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const f return static_cast(scale_rcp); } -template +template __device__ __forceinline__ void colwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, @@ -232,14 +231,13 @@ __device__ __forceinline__ void colwise_scaling( static_cast(__habs(thread_amax_2x.y))}; #pragma unroll for (int w = 0; w < 2; ++w) { - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { __align__(8) uint32_t rOut[SCALE_DIM / 8]; nvfp4_scale_t S_dec_b_fp8; const auto scaling_factors = core::compute_4over6_nvfp4_quantization_scaling_factors(block_amax[w], S_enc_colwise); - core::quantize_4over6_contiguous_16x( + core::quantize_4over6_contiguous_16x( rIn[w], scaling_factors, global_amax_colwise, S_dec_b_fp8, rOut); // Store scaling factors to SMEM buffer (R2S) @@ -281,9 +279,8 @@ __device__ __forceinline__ void colwise_scaling( } } -template +template __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, @@ -335,7 +332,7 @@ __device__ __forceinline__ void rowwise_scaling( } const float block_amax = get_amax_of_pair(thread_amax_2x); - if constexpr (USE_4OVER6) { + if constexpr (FourOverSixConfig::enabled) { nvfp4_scale_t S_dec_b_fp8; float block_S_enc_rowwise; float block_global_amax; @@ -344,8 +341,7 @@ __device__ __forceinline__ void rowwise_scaling( if (row_idx < rows) { block_global_amax = amax_rowwise_ptr[row_idx]; block_S_enc_rowwise = - core::compute_global_encode_scaling_factor_FP4( - block_global_amax); + core::compute_global_encode_scaling_factor_FP4(block_global_amax); } else { block_global_amax = 1.0f; block_S_enc_rowwise = 1.0f; @@ -359,12 +355,10 @@ __device__ __forceinline__ void rowwise_scaling( __align__(8) uint32_t rOut[WAVES]; if (bank_group == 0) { - core::quantize_4over6_pair_array_16x( + core::quantize_4over6_pair_array_16x( rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } else { - core::quantize_4over6_pair_array_16x( + core::quantize_4over6_pair_array_16x( rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); } @@ -432,8 +426,7 @@ __device__ __forceinline__ void rowwise_scaling( } template + bool ROW_SCALED_NVFP4, int E4M3_MAX, typename FourOverSixConfig> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -501,14 +494,12 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f - : core::compute_global_encode_scaling_factor_FP4( - *amax_rowwise_ptr); + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); const float S_enc_colwise = (amax_colwise_ptr == nullptr) ? S_enc_rowwise - : core::compute_global_encode_scaling_factor_FP4( - *amax_colwise_ptr); + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); // Original NVFP4 uses a scalar per-tensor amax for both rowwise and columnwise output. // If no dedicated columnwise amax buffer is allocated, the rowwise amax is that same scalar. const float global_amax_colwise = (amax_colwise_ptr == nullptr) @@ -661,14 +652,13 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( - sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, - amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); + rowwise_scaling(sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, + stage_X, buff_in, buff_out, amax_rowwise_ptr, + block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { - colwise_scaling( + colwise_scaling( sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, global_amax_colwise, stage_Y, stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); } @@ -772,8 +762,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; - const bool use_4over6_e4m3_use_256 = - use_4over6 && quant_config && quant_config->nvfp4_4over6_e4m3_use_256; + const int nvfp4_e4m3_max = use_4over6 && quant_config ? quant_config->nvfp4_e4m3_max : 448; const NVTENVFP44Over6ErrMode use_4over6_err_mode = use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; const bool use_4over6_err_use_fast_math = @@ -894,27 +883,25 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION( row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( - use_4over6_err_mode, USE_4OVER6_ERR_MODE, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_err_use_fast_math, USE_4OVER6_ERR_USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_e4m3_use_256, USE_4OVER6_E4M3_USE_256, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< - USE_STOCHASTIC_ROUNDING, - /*USE_FAST_MATH=*/false, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, - /*USE_4OVER6=*/true, USE_4OVER6_ERR_MODE, - USE_4OVER6_ERR_USE_FAST_MATH, USE_4OVER6_E4M3_USE_256>; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, - scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, - amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, - rng_state); - });););););); + TRANSFORMER_ENGINE_NVFP4_4OVER6_CONFIG_SWITCH( + /*USE_4OVER6_VALUE=*/true, use_4over6_err_mode, use_4over6_err_use_fast_math, + FourOverSixConfig, + TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( + nvfp4_e4m3_max, E4M3_MAX, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< + USE_STOCHASTIC_ROUNDING, + /*USE_FAST_MATH=*/false, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, E4M3_MAX, + FourOverSixConfig>; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, + scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, + amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, + rng_state); + }););););); } else { const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; TRANSFORMER_ENGINE_SWITCH_CONDITION( @@ -926,9 +913,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< USE_STOCHASTIC_ROUNDING, USE_FAST_MATH, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, - /*USE_4OVER6=*/false, kNVTENVFP44Over6ErrMAE, - /*USE_4OVER6_ERR_USE_FAST_MATH=*/false, - /*USE_4OVER6_E4M3_USE_256=*/false>; + /*E4M3_MAX=*/448, core::NVFP44Over6DisabledConfig>; cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 34f687ab9d..eeb801b758 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -234,8 +234,8 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz chunk.set_nvfp4_4over6(source.get_nvfp4_4over6()); continue; } - if (param_type == NVTETensorParam::kNVTENVFP44Over6E4M3Use256) { - chunk.set_nvfp4_4over6_e4m3_use_256(source.get_nvfp4_4over6_e4m3_use_256()); + if (param_type == NVTETensorParam::kNVTENVFP4E4M3Max) { + chunk.set_nvfp4_e4m3_max(source.get_nvfp4_e4m3_max()); continue; } auto param = source.get_parameter(param_type); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index a1c31bae20..154dab3143 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -184,12 +184,12 @@ struct Tensor { * map-to-4/map-to-6 candidate for each 1x16 block. */ bool nvfp4_4over6 = false; - /*! \brief Whether NVFP4 4over6 uses 256 as the global E4M3 scale bound. + /*! \brief Global E4M3 scale bound used by NVFP4. * - * Only meaningful when nvfp4_4over6 is true. If false, the standard NVFP4 - * E4M3 bound 448 is used. + * Standard NVFP4 uses 448. Some 4over6 tensors use 256 to leave room for + * map-to-4 local scale expansion. */ - bool nvfp4_4over6_e4m3_use_256 = false; + int nvfp4_e4m3_max = 448; /*! Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -203,7 +203,7 @@ struct Tensor { sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales sizeof(uint8_t), // kNVTERowScaledNVFP4 sizeof(uint8_t), // kNVTENVFP44Over6 - sizeof(uint8_t) // kNVTENVFP44Over6E4M3Use256 + sizeof(int) // kNVTENVFP4E4M3Max }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -221,7 +221,7 @@ struct Tensor { with_gemm_swizzled_scales = false; row_scaled_nvfp4 = false; nvfp4_4over6 = false; - nvfp4_4over6_e4m3_use_256 = false; + nvfp4_e4m3_max = 448; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } @@ -494,7 +494,7 @@ struct QuantizationConfig { bool stochastic_rounding = false; bool use_fast_math = false; bool nvfp4_4over6 = false; - bool nvfp4_4over6_e4m3_use_256 = false; + int nvfp4_e4m3_max = 448; NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; bool nvfp4_4over6_err_use_fast_math = false; @@ -508,7 +508,7 @@ struct QuantizationConfig { sizeof(uint8_t), // stochastic_rounding sizeof(uint8_t), // use_fast_math sizeof(uint8_t), // nvfp4_4over6 - sizeof(uint8_t), // nvfp4_4over6_e4m3_use_256 + sizeof(int), // nvfp4_e4m3_max sizeof(uint8_t), // nvfp4_4over6_err_mode sizeof(uint8_t) // nvfp4_4over6_err_use_fast_math }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index dcff2fbe1c..c3b2aada00 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -89,13 +89,13 @@ enum NVTETensorParam { * and map-to-6 candidates. */ kNVTENVFP44Over6 = 9, - /*! Whether an NVFP4 4over6 tensor uses 256 as its global E4M3 scale bound. + /*! Global E4M3 scale bound used by an NVFP4 tensor. * * This is part of the tensor data contract. Downstream dequantization and - * GEMM scale consumers must use the same global E4M3 bound used during - * quantization. + * GEMM scale consumers must use the same bound used during quantization. + * Standard NVFP4 uses 448; 4over6 may use 256 for map-to-4 headroom. */ - kNVTENVFP44Over6E4M3Use256 = 10, + kNVTENVFP4E4M3Max = 10, kNVTENumTensorParams }; @@ -410,12 +410,13 @@ enum NVTEQuantizationConfigAttribute { * kNVTENVFP44Over6 metadata must match this option. */ kNVTEQuantizationConfigNVFP44Over6 = 8, - /*! Whether NVFP4 4over6 should use 256 as the global E4M3 scale bound. + /*! Global E4M3 scale bound to use for NVFP4 quantization. * - * If disabled, 4over6 uses the default NVFP4 448 bound. The output tensor's - * kNVTENVFP44Over6E4M3Use256 metadata must match this option. + * Standard NVFP4 uses 448. Some 4over6 tensors use 256 to leave room for + * map-to-4 local scale expansion. The output tensor's kNVTENVFP4E4M3Max + * metadata must match this option. */ - kNVTEQuantizationConfigNVFP44Over6E4M3Use256 = 9, + kNVTEQuantizationConfigNVFP4E4M3Max = 9, /*! Candidate-selection error mode for NVFP4 4over6 quantization. * * The value is an NVTENVFP44Over6ErrMode encoded as uint8_t. It is only @@ -834,9 +835,9 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTENVFP44Over6, &val, sizeof(val)); } - void set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256) { - const auto val = static_cast(nvfp4_4over6_e4m3_use_256); - nvte_set_tensor_param_v2(tensor_, kNVTENVFP44Over6E4M3Use256, &val, sizeof(val)); + void set_nvfp4_e4m3_max(int nvfp4_e4m3_max) { + const auto val = nvfp4_e4m3_max; + nvte_set_tensor_param_v2(tensor_, kNVTENVFP4E4M3Max, &val, sizeof(val)); } // Parameter getters @@ -887,10 +888,10 @@ class TensorWrapper { return static_cast(val); } - bool get_nvfp4_4over6_e4m3_use_256() const { - uint8_t val = 0; - nvte_get_tensor_param_v2(tensor_, kNVTENVFP44Over6E4M3Use256, &val, sizeof(val), nullptr); - return static_cast(val); + int get_nvfp4_e4m3_max() const { + int val = 448; + nvte_get_tensor_param_v2(tensor_, kNVTENVFP4E4M3Max, &val, sizeof(val), nullptr); + return val; } /*! \brief Get an underlying NVTETensor. @@ -1395,11 +1396,11 @@ class QuantizationConfigWrapper { sizeof(val)); } - /*! \brief Set whether NVFP4 4over6 uses the 256 global E4M3 scale bound */ - void set_nvfp4_4over6_e4m3_use_256(bool nvfp4_4over6_e4m3_use_256) { - const auto val = static_cast(nvfp4_4over6_e4m3_use_256); - nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6E4M3Use256, - &val, sizeof(val)); + /*! \brief Set the global E4M3 scale bound used by NVFP4 quantization */ + void set_nvfp4_e4m3_max(int nvfp4_e4m3_max) { + const auto val = nvfp4_e4m3_max; + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP4E4M3Max, &val, + sizeof(val)); } /*! \brief Set NVFP4 4over6 candidate-selection error mode */ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index f3658efacf..a3f4f2d700 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -533,7 +533,7 @@ class NVFP4BlockScaling(Recipe): supported on tensors that use 4over6; activation and backward scopes therefore require ``disable_rht=True`` and ``disable_stochastic_rounding=True``. - nvfp4_4over6_e4m3_use_256 : {None, 'weights', 'activations', 'all'}, default = None + nvfp4_e4m3_max : {None, 'weights', 'activations', 'all'}, default = None Select 4over6 tensors that use 256 as the global E4M3 scale bound. If unset, 4over6 uses the standard NVFP4 448 bound. nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE' @@ -553,7 +553,7 @@ class NVFP4BlockScaling(Recipe): disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" nvfp4_4over6: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6", None) - nvfp4_4over6_e4m3_use_256: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6_E4M3_USE_256", None) + nvfp4_e4m3_max: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6_E4M3_USE_256", None) nvfp4_4over6_err_mode: str = os.getenv("NVTE_NVFP4_4OVER6_ERR_MODE", "MAE").upper() fp4_format: Format = Format.E2M1 @@ -573,7 +573,7 @@ def __post_init__(self) -> None: assert ( self.nvfp4_4over6 in _NVFP4_4OVER6_SCOPES ), "NVTE_NVFP4_4OVER6 must be unset or one of: 'weights', 'activations', 'all'." - assert self.nvfp4_4over6_e4m3_use_256 in _NVFP4_4OVER6_SCOPES, ( + assert self.nvfp4_e4m3_max in _NVFP4_4OVER6_SCOPES, ( "NVTE_NVFP4_4OVER6_E4M3_USE_256 must be unset or one of: " "'weights', 'activations', 'all'." ) @@ -615,7 +615,7 @@ def _make_repr(self) -> str: f"backward_override={self.backward_override}, " f"row_scaled_activation={self.row_scaled_activation}, " f"nvfp4_4over6={self.nvfp4_4over6}, " - f"nvfp4_4over6_e4m3_use_256={self.nvfp4_4over6_e4m3_use_256}, " + f"nvfp4_e4m3_max={self.nvfp4_e4m3_max}, " f"nvfp4_4over6_err_mode={self.nvfp4_4over6_err_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 4b1950cbd3..576e6139c7 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -924,8 +924,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr; void *alpha_ptr = tOut->data.dptr; - const float fp8_max_A = tA->nvfp4_4over6_e4m3_use_256 ? 256.0f : 448.0f; - const float fp8_max_B = tB->nvfp4_4over6_e4m3_use_256 ? 256.0f : 448.0f; + const float fp8_max_A = static_cast(tA->nvfp4_e4m3_max); + const float fp8_max_B = static_cast(tB->nvfp4_e4m3_max); // check for not null pointers NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1971e78bdc..2378943526 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -858,8 +858,10 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTENVFP44Over6: t.nvfp4_4over6 = static_cast(*reinterpret_cast(buf)); break; - case kNVTENVFP44Over6E4M3Use256: - t.nvfp4_4over6_e4m3_use_256 = static_cast(*reinterpret_cast(buf)); + case kNVTENVFP4E4M3Max: + std::memcpy(&t.nvfp4_e4m3_max, buf, attr_size); + NVTE_CHECK(t.nvfp4_e4m3_max == 448 || t.nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max (got ", t.nvfp4_e4m3_max, ")"); break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); @@ -947,8 +949,8 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTENVFP44Over6: *reinterpret_cast(buf) = static_cast(t->nvfp4_4over6); break; - case kNVTENVFP44Over6E4M3Use256: - *reinterpret_cast(buf) = static_cast(t->nvfp4_4over6_e4m3_use_256); + case kNVTENVFP4E4M3Max: + std::memcpy(buf, &t->nvfp4_e4m3_max, attr_size); break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); @@ -1064,8 +1066,8 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNVFP44Over6: bool_to_uint8(config_.nvfp4_4over6, buf); break; - case kNVTEQuantizationConfigNVFP44Over6E4M3Use256: - bool_to_uint8(config_.nvfp4_4over6_e4m3_use_256, buf); + case kNVTEQuantizationConfigNVFP4E4M3Max: + std::memcpy(buf, &config_.nvfp4_e4m3_max, attr_size); break; case kNVTEQuantizationConfigNVFP44Over6ErrMode: { const auto val = static_cast(config_.nvfp4_4over6_err_mode); @@ -1133,8 +1135,10 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigNVFP44Over6: uint8_to_bool(buf, config_.nvfp4_4over6); break; - case kNVTEQuantizationConfigNVFP44Over6E4M3Use256: - uint8_to_bool(buf, config_.nvfp4_4over6_e4m3_use_256); + case kNVTEQuantizationConfigNVFP4E4M3Max: + std::memcpy(&config_.nvfp4_e4m3_max, buf, attr_size); + NVTE_CHECK(config_.nvfp4_e4m3_max == 448 || config_.nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max (got ", config_.nvfp4_e4m3_max, ")"); break; case kNVTEQuantizationConfigNVFP44Over6ErrMode: { const auto val = *reinterpret_cast(buf); diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 87ead8d186..9586927508 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -68,7 +68,7 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const bool use_4over6_e4m3_use_256, + const bool use_4over6, const int nvfp4_e4m3_max, const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, const bool nvfp4_4over6_err_use_fast_math, const SimpleTensor &noop_tensor, cudaStream_t stream); diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index ff8caa5712..a885ea12da 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -307,9 +307,8 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kRowScaledNVFP4, int kE4M3Max, + typename FourOverSixConfig> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -364,7 +363,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim]; __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; constexpr int k4Over62DSelectionDim = - (kUse4Over6 && kIs2DBlockScaling) ? kFP4BlockScalingSize : 1; + (FourOverSixConfig::enabled && kIs2DBlockScaling) ? kFP4BlockScalingSize : 1; using FourOverSixScratch = nvfp4_core::QuantizationScratch4Over6; @@ -417,8 +416,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; const float global_encode_scale = - kIsE8Scaling ? 1.0f - : compute_global_encode_scaling_factor_FP4(global_amax[0]); + kIsE8Scaling ? 1.0f : compute_global_encode_scaling_factor_FP4(global_amax[0]); constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; const float global_decode_scale = 1.0 / global_encode_scale; @@ -513,9 +511,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float row_global_encode_scale = global_encode_scale; if constexpr (kRowScaledNVFP4) { row_global_encode_scale = - row_idx < num_rows ? compute_global_encode_scaling_factor_FP4( - global_amax[row_idx]) - : 1.0f; + row_idx < num_rows + ? compute_global_encode_scaling_factor_FP4(global_amax[row_idx]) + : 1.0f; } const float row_global_encode_scale_multiplier = kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; @@ -524,7 +522,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ScaleType scale_inv; float encode_scale; OVec output_vec; - if constexpr (kUse4Over6) { + if constexpr (FourOverSixConfig::enabled) { const auto scaling_factors = nvfp4_core::compute_4over6_fp4_encode_quantization_scaling_factors( amax, row_global_encode_scale, row_global_decode_scale); @@ -551,8 +549,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t participant_idx = data_row_idx % kFP4BlockScalingSize; nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec2_array_candidates_16x< - k4Over6ErrMode, kUse4Over6ErrUseFastMath, kUse4Over6E4M3Use256>( + nvfp4_core::quantize_4over6_vec2_array_candidates_16x( smem_vec, scaling_factors, row_global_amax, candidates); const bool pick_map4 = nvfp4_core::record_and_select_4over6_2d_block( + nvfp4_core::quantize_4over6_vec2_array_16x( smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } @@ -591,7 +588,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } // Step 2.6: Quantize - if constexpr (!kUse4Over6) { + if constexpr (!FourOverSixConfig::enabled) { #pragma unroll for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { // Pack two elements into __nv_bfloat162 @@ -698,7 +695,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ScaleType scale_inv; float encode_scale; OVec output_vec; - if constexpr (kUse4Over6) { + if constexpr (FourOverSixConfig::enabled) { const auto scaling_factors = nvfp4_core::compute_4over6_fp4_encode_quantization_scaling_factors( amax, global_encode_scale, global_decode_scale); @@ -717,8 +714,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t participant_idx = data_col_idx % kFP4BlockScalingSize; nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec_index_candidates_16x< - k4Over6ErrMode, kUse4Over6ErrUseFastMath, kUse4Over6E4M3Use256>( + nvfp4_core::quantize_4over6_vec_index_candidates_16x( smem_vec, smem_idx, scaling_factors, global_amax[0], candidates); const bool pick_map4 = nvfp4_core::record_and_select_4over6_2d_block( + nvfp4_core::quantize_4over6_vec_index_16x( smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); } @@ -757,7 +753,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } // Step 3.6: Quantize - if constexpr (!kUse4Over6) { + if constexpr (!FourOverSixConfig::enabled) { #pragma unroll for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { // Pack two elements into __nv_bfloat162 @@ -814,7 +810,7 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const bool use_4over6_e4m3_use_256, + const bool use_4over6, const int nvfp4_e4m3_max, const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, const bool nvfp4_4over6_err_use_fast_math, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); @@ -835,12 +831,14 @@ void quantize_transpose_vector_blockwise_fp4( "Row-scaled NVFP4 quantization does not support 2D quantization."); NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); - NVTE_CHECK(use_4over6 || !use_4over6_e4m3_use_256, - "NVFP4 4over6 E4M3 256 scale bound requires 4over6 quantization."); + NVTE_CHECK(nvfp4_e4m3_max == 448 || nvfp4_e4m3_max == 256, "Unsupported NVFP4 E4M3 max (got ", + nvfp4_e4m3_max, ")"); + NVTE_CHECK(use_4over6 || nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); const NVTENVFP44Over6ErrMode use_4over6_err_mode = use_4over6 ? nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; const bool use_4over6_err_use_fast_math = use_4over6 && nvfp4_4over6_err_use_fast_math; - const bool enabled_4over6_e4m3_use_256 = use_4over6 && use_4over6_e4m3_use_256; + const int enabled_nvfp4_e4m3_max = use_4over6 ? nvfp4_e4m3_max : 448; const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -921,8 +919,8 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( row_scaled_nvfp4, kRowScaledNVFP4, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - enabled_4over6_e4m3_use_256, kUse4Over6E4M3Use256, + TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( + enabled_nvfp4_e4m3_max, kE4M3Max, TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( use_4over6_err_mode, k4Over6ErrMode, @@ -932,14 +930,17 @@ void quantize_transpose_vector_blockwise_fp4( kUse4Over6ErrUseFastMath, size_t smem_bytes = kSMemSize * sizeof(InputType); + using FourOverSixConfig = + nvfp4_core::NVFP44Over6Config< + true, k4Over6ErrMode, + kUse4Over6ErrUseFastMath>; auto kernel = block_scaled_1d_cast_transpose_kernel< kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, float, InputType, OutputType, ScaleType, kSwizzledScale, /*kApplyStochasticRounding=*/false, - kIs2DBlockScaling, kRowScaledNVFP4, - /*kUse4Over6=*/true, kUse4Over6E4M3Use256, - k4Over6ErrMode, kUse4Over6ErrUseFastMath>; + kIs2DBlockScaling, kRowScaledNVFP4, kE4M3Max, + FourOverSixConfig>; if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( kernel, @@ -963,7 +964,7 @@ void quantize_transpose_vector_blockwise_fp4( rng_state, noop_ptr);) // kUse4Over6ErrUseFastMath ) // k4Over6ErrMode - ) // kUse4Over6E4M3Use256 + ) // kE4M3Max ) // kRowScaledNVFP4 ) // kIs2DBlockScaling } else { @@ -981,9 +982,8 @@ void quantize_transpose_vector_blockwise_fp4( kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, float, InputType, OutputType, ScaleType, kSwizzledScale, kApplyStochasticRounding, kIs2DBlockScaling, - kRowScaledNVFP4, /*kUse4Over6=*/false, - /*kUse4Over6E4M3Use256=*/false, kNVTENVFP44Over6ErrMAE, - /*kUse4Over6ErrUseFastMath=*/false>; + kRowScaledNVFP4, /*kE4M3Max=*/448, + nvfp4_core::NVFP44Over6DisabledConfig>; if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8b86f26b24..2ce132adb4 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -329,8 +329,8 @@ class NVFP4Quantizer : public Quantizer { bool stochastic_rounding; // Whether emitted NVFP4 tensors use 4over6 candidate selection. bool use_4over6; - // Whether emitted NVFP4 4over6 tensors use 256 as the global E4M3 scale bound. - bool four_over_six_e4m3_use_256; + // Global E4M3 scale bound used by emitted NVFP4 tensors. + int nvfp4_e4m3_max; NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode; // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. bool row_scaled_nvfp4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index caaca112d2..d64098903d 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -725,7 +725,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const bool use_4over6 = quantizer_cpp_list[0]->use_4over6; - const bool four_over_six_e4m3_use_256 = quantizer_cpp_list[0]->four_over_six_e4m3_use_256; + const int nvfp4_e4m3_max = quantizer_cpp_list[0]->nvfp4_e4m3_max; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); @@ -874,7 +874,7 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales, py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, py::arg("use_4over6") = use_4over6, - py::arg("four_over_six_e4m3_use_256") = four_over_six_e4m3_use_256)); + py::arg("nvfp4_e4m3_max") = nvfp4_e4m3_max)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -893,7 +893,7 @@ std::tuple, std::vector, bool> bulk_alloc tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); tensor_wrapper.set_nvfp4_4over6(use_4over6); - tensor_wrapper.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); + tensor_wrapper.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -1043,12 +1043,12 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, for (auto &config : quant_config_list) { config.set_nvfp4_4over6(quantizer.use_4over6); - config.set_nvfp4_4over6_e4m3_use_256(quantizer.four_over_six_e4m3_use_256); + config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } for (auto &config : quant_config_list_colwise) { config.set_nvfp4_4over6(quantizer.use_4over6); - config.set_nvfp4_4over6_e4m3_use_256(quantizer.four_over_six_e4m3_use_256); + config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } @@ -1226,7 +1226,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, for (auto &config : quant_config_list) { config.set_nvfp4_4over6(quantizer.use_4over6); - config.set_nvfp4_4over6_e4m3_use_256(quantizer.four_over_six_e4m3_use_256); + config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index e671fa996c..6d6e173e1b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1730,9 +1730,10 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); this->use_4over6 = quantizer.attr("use_4over6").cast(); - this->four_over_six_e4m3_use_256 = - this->use_4over6 && quantizer.attr("four_over_six_e4m3_use_256").cast(); - const auto nvfp4_4over6_err_mode = quantizer.attr("four_over_six_err_mode").cast(); + this->nvfp4_e4m3_max = this->use_4over6 ? quantizer.attr("nvfp4_e4m3_max").cast() : 448; + NVTE_CHECK(this->nvfp4_e4m3_max == 448 || this->nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max: ", this->nvfp4_e4m3_max); + const auto nvfp4_4over6_err_mode = quantizer.attr("nvfp4_4over6_err_mode").cast(); if (nvfp4_4over6_err_mode == "MAE") { this->nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; } else if (nvfp4_4over6_err_mode == "MSE") { @@ -1790,7 +1791,7 @@ std::pair NVFP4Quantizer::create_tensor( " (got shape=", shape, ")"); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; const bool use_4over6 = this->use_4over6; - const bool four_over_six_e4m3_use_256 = this->four_over_six_e4m3_use_256; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -1859,7 +1860,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["use_4over6"] = py::cast(use_4over6); - kwargs["four_over_six_e4m3_use_256"] = py::cast(four_over_six_e4m3_use_256); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1891,7 +1892,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["device"] = py::cast(device); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["use_4over6"] = py::cast(use_4over6); - kwargs["four_over_six_e4m3_use_256"] = py::cast(four_over_six_e4m3_use_256); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1926,7 +1927,7 @@ std::pair NVFP4Quantizer::create_tensor( out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); out_cpp.set_nvfp4_4over6(use_4over6); - out_cpp.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); + out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1956,7 +1957,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; const bool use_4over6 = this->use_4over6; - const bool four_over_six_e4m3_use_256 = this->four_over_six_e4m3_use_256; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -2032,7 +2033,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["use_4over6"] = py::cast(use_4over6); - kwargs["four_over_six_e4m3_use_256"] = py::cast(four_over_six_e4m3_use_256); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -2109,7 +2110,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; const bool use_4over6 = this->use_4over6; - const bool four_over_six_e4m3_use_256 = this->four_over_six_e4m3_use_256; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -2117,7 +2118,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); tensor.attr("_use_4over6") = py::cast(use_4over6); - tensor.attr("_four_over_six_e4m3_use_256") = py::cast(four_over_six_e4m3_use_256); + tensor.attr("_nvfp4_e4m3_max") = py::cast(nvfp4_e4m3_max); // Coerce row-wise data if (rowwise_usage) { @@ -2223,7 +2224,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); out_cpp.set_nvfp4_4over6(use_4over6); - out_cpp.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); + out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2315,10 +2316,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); quant_config.set_nvfp4_4over6(this->use_4over6); - quant_config.set_nvfp4_4over6_e4m3_use_256(this->four_over_six_e4m3_use_256); + quant_config.set_nvfp4_e4m3_max(this->nvfp4_e4m3_max); quant_config.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); quant_config_columnwise.set_nvfp4_4over6(this->use_4over6); - quant_config_columnwise.set_nvfp4_4over6_e4m3_use_256(this->four_over_six_e4m3_use_256); + quant_config_columnwise.set_nvfp4_e4m3_max(this->nvfp4_e4m3_max); quant_config_columnwise.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); if (this->use_4over6) { diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index fc707a209d..e25dacf76e 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -136,7 +136,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); const bool use_4over6 = tensor.attr("_use_4over6").cast(); - const bool four_over_six_e4m3_use_256 = tensor.attr("_four_over_six_e4m3_use_256").cast(); + const int nvfp4_e4m3_max = tensor.attr("_nvfp4_e4m3_max").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -168,7 +168,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); ret.set_row_scaled_nvfp4(row_scaled_nvfp4); ret.set_nvfp4_4over6(use_4over6); - ret.set_nvfp4_4over6_e4m3_use_256(four_over_six_e4m3_use_256); + ret.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 7d8e8d4745..627b098b1a 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -222,7 +222,7 @@ class NVFP4TensorRef(QuantizedTensorStorage): global_amax_row: Optional[torch.Tensor] = None global_amax_col: Optional[torch.Tensor] = None use_4over6: bool = False - four_over_six_e4m3_use_256: bool = False + nvfp4_e4m3_max: int = 448 dtype: Optional[torch.dtype] = None device: Optional[torch.device] = None @@ -353,14 +353,14 @@ def __init__( quant_tile_shape: Tuple[int, int] = (1, 16), row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, - four_over_six_err_mode: str = "MAE", + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", with_rht: bool = False, with_random_sign_mask: bool = True, ): - four_over_six_err_mode = four_over_six_err_mode.upper() - if four_over_six_err_mode not in ("MAE", "MSE"): - raise ValueError("four_over_six_err_mode must be 'MAE' or 'MSE'.") + nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() + if nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -384,8 +384,10 @@ def __init__( self.quant_tile_shape = quant_tile_shape self.row_scaled_nvfp4 = row_scaled_nvfp4 self.use_4over6 = use_4over6 - self.four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 - self.four_over_six_err_mode = four_over_six_err_mode + self.nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 + if self.nvfp4_e4m3_max not in (448, 256): + raise ValueError("nvfp4_e4m3_max must be 448 or 256.") + self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -473,8 +475,8 @@ def _quantize_blockwise_4over6_reference( global_decode_scale: torch.Tensor, row_scaled_nvfp4: bool, tile_len_y: int, - four_over_six_err_mode: str, - four_over_six_e4m3_use_256: bool, + nvfp4_4over6_err_mode: str, + nvfp4_e4m3_max: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize NVFP4 with 4over6 candidate selection. @@ -487,9 +489,7 @@ def _quantize_blockwise_4over6_reference( FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) GLOBAL_SCALE_E4M3_MAX = torch.tensor( - 256.0 if four_over_six_e4m3_use_256 else 448.0, - device=x.device, - dtype=torch.float32, + float(nvfp4_e4m3_max), device=x.device, dtype=torch.float32 ) decode_scale_base = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale @@ -546,7 +546,7 @@ def _quantize_blockwise_4over6_reference( val_map4 = val_map4 * error_global_amax val_map4 = val_map4 / denom diff_map4 = val_map4 - x_float[:, :, idx] - if four_over_six_err_mode == "MSE": + if nvfp4_4over6_err_mode == "MSE": err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) else: err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) @@ -555,7 +555,7 @@ def _quantize_blockwise_4over6_reference( val_map6 = val_map6 * error_global_amax val_map6 = val_map6 / denom diff_map6 = val_map6 - x_float[:, :, idx] - if four_over_six_err_mode == "MSE": + if nvfp4_4over6_err_mode == "MSE": err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) else: err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) @@ -584,8 +584,8 @@ def _quantize_blockwise_reference( pow_2_scales: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, - four_over_six_err_mode: str = "MAE", + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -618,7 +618,7 @@ def _quantize_blockwise_reference( x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) - global_scale_e4m3_max = 256.0 if (use_4over6 and four_over_six_e4m3_use_256) else 448.0 + global_scale_e4m3_max = float(nvfp4_e4m3_max if use_4over6 else 448) GLOBAL_SCALE_E4M3_MAX = torch.tensor( global_scale_e4m3_max, device=x.device, dtype=torch.float32 ) @@ -665,8 +665,8 @@ def _quantize_blockwise_reference( global_decode_scale, row_scaled_nvfp4, tile_len_y, - four_over_six_err_mode, - four_over_six_e4m3_use_256, + nvfp4_4over6_err_mode, + nvfp4_e4m3_max, ) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) @@ -830,8 +830,8 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ pow_2_scales=self.pow_2_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, use_4over6=self.use_4over6, - four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, - four_over_six_err_mode=self.four_over_six_err_mode, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, eps=self.eps, ) if transpose_scales: @@ -856,8 +856,8 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, use_4over6=self.use_4over6, - four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, - four_over_six_err_mode=self.four_over_six_err_mode, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, eps=self.eps, ) @@ -898,7 +898,7 @@ def quantize( global_amax_row=global_amax_row, global_amax_col=global_amax_col, use_4over6=self.use_4over6, - four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, dtype=tensor.dtype, device=tensor.device, quant_dtype=self.dtype, @@ -947,7 +947,7 @@ def update_quantized( dst.global_amax_row = global_amax_row dst.global_amax_col = global_amax_col dst.use_4over6 = self.use_4over6 - dst.four_over_six_e4m3_use_256 = self.four_over_six_e4m3_use_256 + dst.nvfp4_e4m3_max = self.nvfp4_e4m3_max dst.dtype = src.dtype dst.quant_dtype = self.dtype dst.original_shape = original_shape @@ -1059,22 +1059,22 @@ def qgemm( qresult_w_use_4over6 = getattr( qresult_w, "use_4over6", getattr(qresult_w, "_use_4over6", self.use_4over6) ) - qresult_x_use_256 = getattr( + qresult_x_e4m3_max = getattr( qresult_x, - "four_over_six_e4m3_use_256", - getattr(qresult_x, "_four_over_six_e4m3_use_256", self.four_over_six_e4m3_use_256), + "nvfp4_e4m3_max", + getattr(qresult_x, "_nvfp4_e4m3_max", self.nvfp4_e4m3_max), ) - qresult_w_use_256 = getattr( + qresult_w_e4m3_max = getattr( qresult_w, - "four_over_six_e4m3_use_256", - getattr(qresult_w, "_four_over_six_e4m3_use_256", self.four_over_six_e4m3_use_256), + "nvfp4_e4m3_max", + getattr(qresult_w, "_nvfp4_e4m3_max", self.nvfp4_e4m3_max), ) - if qresult_x_use_4over6 and qresult_x_use_256: - fp8_max_x = 256.0 + if qresult_x_use_4over6: + fp8_max_x = float(qresult_x_e4m3_max) else: fp8_max_x = 448.0 - if qresult_w_use_4over6 and qresult_w_use_256: - fp8_max_w = 256.0 + if qresult_w_use_4over6: + fp8_max_w = float(qresult_w_e4m3_max) else: fp8_max_w = 448.0 factor = 6.0 * 6.0 * fp8_max_x * fp8_max_w diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index bab8d15559..b29119798c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1662,14 +1662,16 @@ def _make(tensor_type: str) -> NVFP4Quantizer: use_4over6 = tensor_type == "weight" elif self.recipe.nvfp4_4over6 == "activations": use_4over6 = tensor_type != "weight" - use_4over6_e4m3_use_256 = False + nvfp4_e4m3_max = 448 if use_4over6: - if self.recipe.nvfp4_4over6_e4m3_use_256 == "all": - use_4over6_e4m3_use_256 = True - elif self.recipe.nvfp4_4over6_e4m3_use_256 == "weights": - use_4over6_e4m3_use_256 = tensor_type == "weight" - elif self.recipe.nvfp4_4over6_e4m3_use_256 == "activations": - use_4over6_e4m3_use_256 = tensor_type != "weight" + if self.recipe.nvfp4_e4m3_max == "all": + nvfp4_e4m3_max = 256 + elif self.recipe.nvfp4_e4m3_max == "weights": + if tensor_type == "weight": + nvfp4_e4m3_max = 256 + elif self.recipe.nvfp4_e4m3_max == "activations": + if tensor_type != "weight": + nvfp4_e4m3_max = 256 return NVFP4Quantizer( fp4_dtype=self.dtype, rowwise=True, @@ -1684,8 +1686,8 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and self.recipe.row_scaled_activation ), use_4over6=use_4over6, - four_over_six_e4m3_use_256=use_4over6_e4m3_use_256, - four_over_six_err_mode=self.recipe.nvfp4_4over6_err_mode, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.recipe.nvfp4_4over6_err_mode, ) if self.mode not in ("forward", "backward"): diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index e47389dc27..51e89d8829 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -94,7 +94,7 @@ def __new__( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, + nvfp4_e4m3_max: int = 448, ): if ( shapes is not None @@ -169,7 +169,7 @@ def __new__( with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) return instance @@ -203,7 +203,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 dst.use_4over6 = src.use_4over6 - dst.four_over_six_e4m3_use_256 = src.four_over_six_e4m3_use_256 + dst.nvfp4_e4m3_max = src.nvfp4_e4m3_max def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index cef3bd4f2b..afc9e102d5 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -132,10 +132,10 @@ class NVFP4Quantizer(Quantizer): row_scaled_nvfp4: bool """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection.""" use_4over6: bool - """Whether 4over6 uses 256 instead of 448 as the global E4M3 scale bound.""" - four_over_six_e4m3_use_256: bool + """Global E4M3 scale bound used by emitted NVFP4 tensors.""" + nvfp4_e4m3_max: int """NVFP4 4over6 candidate-selection error mode.""" - four_over_six_err_mode: str + nvfp4_4over6_err_mode: str """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -154,8 +154,8 @@ def __init__( stochastic_rounding: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, - four_over_six_err_mode: str = "MAE", + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -168,10 +168,12 @@ def __init__( self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 self.use_4over6 = use_4over6 - self.four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 - self.four_over_six_err_mode = four_over_six_err_mode.upper() - if self.four_over_six_err_mode not in ("MAE", "MSE"): - raise ValueError("four_over_six_err_mode must be 'MAE' or 'MSE'.") + self.nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 + if self.nvfp4_e4m3_max not in (448, 256): + raise ValueError("nvfp4_e4m3_max must be 448 or 256.") + self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() + if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -219,8 +221,8 @@ def copy(self) -> NVFP4Quantizer: stochastic_rounding=self.stochastic_rounding, row_scaled_nvfp4=self.row_scaled_nvfp4, use_4over6=self.use_4over6, - four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, - four_over_six_err_mode=self.four_over_six_err_mode, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -374,7 +376,7 @@ def __new__( with_gemm_swizzled_scales: bool, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, + nvfp4_e4m3_max: int = 448, **kwargs, ): instance = super().__new__( @@ -391,7 +393,7 @@ def __new__( *args, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, **kwargs, ) return instance @@ -551,7 +553,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m self._amax_columnwise, self._row_scaled_nvfp4, self._use_4over6, - self._four_over_six_e4m3_use_256, + self._nvfp4_e4m3_max, self.shape[-1], ) return sharded_tensors, metadata @@ -577,7 +579,7 @@ def fsdp_post_all_gather( amax_columnwise, row_scaled_nvfp4, use_4over6, - four_over_six_e4m3_use_256, + nvfp4_e4m3_max, K, ) = metadata @@ -604,7 +606,7 @@ def fsdp_post_all_gather( out._amax_columnwise = amax_columnwise out._row_scaled_nvfp4 = row_scaled_nvfp4 out._use_4over6 = use_4over6 - out._four_over_six_e4m3_use_256 = four_over_six_e4m3_use_256 + out._nvfp4_e4m3_max = nvfp4_e4m3_max else: # Construct new tensor (first iteration) out = NVFP4Tensor( @@ -622,7 +624,7 @@ def fsdp_post_all_gather( with_gemm_swizzled_scales=False, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) # Derive columnwise data locally via transpose instead of all-gathering it @@ -763,7 +765,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, use_4over6=tensor._use_4over6, - four_over_six_e4m3_use_256=tensor._four_over_six_e4m3_use_256, + nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) # Default case @@ -785,7 +787,7 @@ def _make_in_reduce_ex( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, + nvfp4_e4m3_max: int = 448, ) -> NVFP4Tensor: """Build NVFP4Tensor, for use in __reduce__ @@ -808,7 +810,7 @@ def _make_in_reduce_ex( with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) def __reduce_ex__(self, protocol: int) -> tuple: @@ -829,7 +831,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._with_gemm_swizzled_scales, self._row_scaled_nvfp4, self._use_4over6, - self._four_over_six_e4m3_use_256, + self._nvfp4_e4m3_max, ), ) @@ -884,7 +886,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales self._row_scaled_nvfp4 = tensor._row_scaled_nvfp4 self._use_4over6 = tensor._use_4over6 - self._four_over_six_e4m3_use_256 = tensor._four_over_six_e4m3_use_256 + self._nvfp4_e4m3_max = tensor._nvfp4_e4m3_max return # Quantize to FP8 @@ -1007,7 +1009,7 @@ def forward( with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, use_4over6=tensor._use_4over6, - four_over_six_e4m3_use_256=tensor._four_over_six_e4m3_use_256, + nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) @staticmethod @@ -1052,7 +1054,7 @@ def backward( with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, row_scaled_nvfp4=grad._row_scaled_nvfp4, use_4over6=grad._use_4over6, - four_over_six_e4m3_use_256=grad._four_over_six_e4m3_use_256, + nvfp4_e4m3_max=grad._nvfp4_e4m3_max, ) return dgrad, None return grad.view(ctx.shape), None @@ -1139,7 +1141,7 @@ def forward( with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, use_4over6=tensor._use_4over6, - four_over_six_e4m3_use_256=tensor._four_over_six_e4m3_use_256, + nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) @staticmethod @@ -1184,7 +1186,7 @@ def backward( with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, row_scaled_nvfp4=grad._row_scaled_nvfp4, use_4over6=grad._use_4over6, - four_over_six_e4m3_use_256=grad._four_over_six_e4m3_use_256, + nvfp4_e4m3_max=grad._nvfp4_e4m3_max, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index dfde0e9e18..04e396be31 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -74,7 +74,7 @@ def _initialize_storage_fields( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, + nvfp4_e4m3_max: int = 448, ) -> None: """ Initialize a GroupedTensor. @@ -152,7 +152,7 @@ def _initialize_storage_fields( instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance.row_scaled_nvfp4 = row_scaled_nvfp4 instance.use_4over6 = use_4over6 - instance.four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 + instance.nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 def __new__( cls, @@ -180,7 +180,7 @@ def __new__( with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, + nvfp4_e4m3_max: int = 448, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -208,7 +208,7 @@ def __new__( with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) return instance @@ -334,13 +334,13 @@ def use_4over6(self, use_4over6: bool) -> None: self._use_4over6 = use_4over6 @property - def four_over_six_e4m3_use_256(self) -> bool: - """Whether grouped NVFP4 4over6 tensors use the 256 E4M3 scale bound.""" - return self._four_over_six_e4m3_use_256 + def nvfp4_e4m3_max(self) -> int: + """Global E4M3 scale bound used by grouped NVFP4 tensors.""" + return self._nvfp4_e4m3_max - @four_over_six_e4m3_use_256.setter - def four_over_six_e4m3_use_256(self, four_over_six_e4m3_use_256: bool) -> None: - self._four_over_six_e4m3_use_256 = four_over_six_e4m3_use_256 + @nvfp4_e4m3_max.setter + def nvfp4_e4m3_max(self, nvfp4_e4m3_max: int) -> None: + self._nvfp4_e4m3_max = nvfp4_e4m3_max def prepare_for_saving( self, @@ -412,7 +412,7 @@ def clear(self) -> None: self.fake_dtype = torch.float32 self.row_scaled_nvfp4 = False self.use_4over6 = False - self.four_over_six_e4m3_use_256 = False + self.nvfp4_e4m3_max = 448 def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -583,7 +583,7 @@ def copy(self) -> "GroupedTensorStorage": with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, use_4over6=self.use_4over6, - four_over_six_e4m3_use_256=self.four_over_six_e4m3_use_256, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, ) @staticmethod @@ -696,7 +696,7 @@ def make_grouped_tensor( columnwise_scale_inv_offsets = None row_scaled_nvfp4 = False use_4over6 = False - four_over_six_e4m3_use_256 = False + nvfp4_e4m3_max = 448 if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -757,7 +757,7 @@ def make_grouped_tensor( elif quantizer._get_compatible_recipe().nvfp4(): row_scaled_nvfp4 = quantizer.row_scaled_nvfp4 use_4over6 = quantizer.use_4over6 - four_over_six_e4m3_use_256 = quantizer.four_over_six_e4m3_use_256 + nvfp4_e4m3_max = quantizer.nvfp4_e4m3_max if row_scaled_nvfp4: if not rowwise_usage: raise ValueError( @@ -887,7 +887,7 @@ def make_grouped_tensor( ), row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -1003,7 +1003,7 @@ def split_into_quantized_tensors( nvfp4_rowwise_amax_offsets = None row_scaled_nvfp4 = self.row_scaled_nvfp4 use_4over6 = self.use_4over6 - four_over_six_e4m3_use_256 = self.four_over_six_e4m3_use_256 + nvfp4_e4m3_max = self.nvfp4_e4m3_max if recipe.nvfp4() and row_scaled_nvfp4: cum = 0 nvfp4_rowwise_amax_offsets = [0] @@ -1232,7 +1232,7 @@ def split_into_quantized_tensors( with_gemm_swizzled_scales=quantizer.optimize_for_gemm, row_scaled_nvfp4=row_scaled_nvfp4, use_4over6=use_4over6, - four_over_six_e4m3_use_256=four_over_six_e4m3_use_256, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 6c3f8ca70c..cc165d5c3e 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -101,8 +101,8 @@ class NVFP4TensorStorage(QuantizedTensorStorage): _row_scaled_nvfp4: bool # Whether this NVFP4 tensor uses 4over6 map-to-4/map-to-6 block selection _use_4over6: bool - # Whether this 4over6 tensor uses 256 instead of 448 as the global E4M3 scale bound - _four_over_six_e4m3_use_256: bool + # Global E4M3 scale bound used by this NVFP4 tensor + _nvfp4_e4m3_max: int def __new__( cls, @@ -119,7 +119,7 @@ def __new__( fake_dtype: Optional[torch.dtype] = None, row_scaled_nvfp4: bool = False, use_4over6: bool = False, - four_over_six_e4m3_use_256: bool = False, + nvfp4_e4m3_max: int = 448, **kwargs, ): if cls is NVFP4TensorStorage: @@ -139,7 +139,7 @@ def __new__( instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance._row_scaled_nvfp4 = row_scaled_nvfp4 instance._use_4over6 = use_4over6 - instance._four_over_six_e4m3_use_256 = use_4over6 and four_over_six_e4m3_use_256 + instance._nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 return instance @@ -168,7 +168,7 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") if self._use_4over6 != src._use_4over6: raise RuntimeError("NVFP4 4over6 mode mismatch in copy_from_storage") - if self._four_over_six_e4m3_use_256 != src._four_over_six_e4m3_use_256: + if self._nvfp4_e4m3_max != src._nvfp4_e4m3_max: raise RuntimeError("NVFP4 4over6 E4M3 scale bound mismatch in copy_from_storage") def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): @@ -196,7 +196,7 @@ def get_metadata(self) -> Dict[str, Any]: "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, "row_scaled_nvfp4": self._row_scaled_nvfp4, "use_4over6": self._use_4over6, - "four_over_six_e4m3_use_256": self._four_over_six_e4m3_use_256, + "nvfp4_e4m3_max": self._nvfp4_e4m3_max, "fake_dtype": self._dtype, } @@ -331,7 +331,7 @@ def view(self, shape: torch.Size): with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self._row_scaled_nvfp4, use_4over6=self._use_4over6, - four_over_six_e4m3_use_256=self._four_over_six_e4m3_use_256, + nvfp4_e4m3_max=self._nvfp4_e4m3_max, fake_dtype=self._dtype, ) From 3cdd9d9ce7c71dfb94c0f3f9d55fe02e489ee58e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 18:12:57 -0700 Subject: [PATCH 51/57] Add benchmark script and minor optimization Signed-off-by: Ziang Li --- benchmarks/benchmark_4over6.py | 107 ++++++++++++++ .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 138 +++++++----------- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 32 ++-- 3 files changed, 173 insertions(+), 104 deletions(-) create mode 100644 benchmarks/benchmark_4over6.py diff --git a/benchmarks/benchmark_4over6.py b/benchmarks/benchmark_4over6.py new file mode 100644 index 0000000000..af2b7b611a --- /dev/null +++ b/benchmarks/benchmark_4over6.py @@ -0,0 +1,107 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os + +import torch +import torch.utils.benchmark as benchmark +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + +SHAPES = [ + (16384, 6144), +] +MIN_RUN_TIME = 5 + + +def make_quantizer(use_2d_quantization, use_4over6, err_mode): + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=use_2d_quantization, + stochastic_rounding=False, + row_scaled_nvfp4=False, + use_4over6=use_4over6, + nvfp4_e4m3_max=448, + nvfp4_4over6_err_mode=err_mode, + with_random_sign_mask=True, + ) + + +def set_err_fast_math(enabled): + os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" + + +def benchmark_quantize(shape, use_2d_quantization, use_4over6, err_mode, err_fast_math): + set_err_fast_math(err_fast_math) + + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) + out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) + quantizer.update_quantized(x, out) + torch.cuda.synchronize() + + timing = benchmark.Timer( + stmt="quantizer.update_quantized(x, out)", + globals={"quantizer": quantizer, "x": x, "out": out}, + num_threads=1, + ).blocked_autorange(min_run_time=MIN_RUN_TIME) + return timing.median * 1e6 + + +def main(): + rows = [] + for shape in SHAPES: + for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): + baseline_us = benchmark_quantize( + shape=shape, + use_2d_quantization=use_2d_quantization, + use_4over6=False, + err_mode="MAE", + err_fast_math=False, + ) + rows.append((shape, mode_name, "nvfp4", "-", "-", baseline_us, 1.0)) + + for err_mode in ("MAE", "MSE"): + for err_fast_math in (False, True): + timing_us = benchmark_quantize( + shape=shape, + use_2d_quantization=use_2d_quantization, + use_4over6=True, + err_mode=err_mode, + err_fast_math=err_fast_math, + ) + rows.append( + ( + shape, + mode_name, + "4over6", + err_mode, + str(err_fast_math), + timing_us, + timing_us / baseline_us, + ) + ) + + print( + f"{'shape':>18} {'mode':>4} {'kernel':>7} {'err':>3} " + f"{'err_fast':>8} {'time_us':>10} {'slowdown':>8}" + ) + for shape, mode_name, kernel, err_mode, err_fast_math, timing_us, slowdown in rows: + print( + f"{str(shape):>18} {mode_name:>4} {kernel:>7} {err_mode:>3} " + f"{err_fast_math:>8} {timing_us:10.3f} {slowdown:8.3f}x" + ) + + +if __name__ == "__main__": + main() diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 0f75d00498..a1d9b52d28 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -160,6 +160,30 @@ __device__ __forceinline__ float compute_4over6_error(const float diff) { } } +template +__device__ __forceinline__ void accumulate_4over6_dequant_error(const uint32_t dequant_bits, + const float x, const float sf, + const float global_amax, + float *err) { + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float fp8_4over6_max = static_cast(E4M3_MAX); + constexpr float err_denom = fp4_max * fp8_4over6_max; + const uint16_t half_bits = (dequant_bits >> SHIFT) & 0xFFFF; + + if constexpr (FourOverSixConfig::err_use_fast_math) { + const float val = __half2float(__ushort_as_half(half_bits)) * sf * global_amax / err_denom; + const float diff = val - x; + *err += compute_4over6_error(diff); + } else { + const float val = + __fdiv_rn(__fmul_rn(__fmul_rn(__half2float(__ushort_as_half(half_bits)), sf), global_amax), + err_denom); + const float diff = __fsub_rn(val, x); + *err = __fadd_rn(*err, compute_4over6_error_rn(diff)); + } +} + template __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t S_dec_b_fp8, @@ -174,12 +198,6 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { - float x_scaled[8]; -#pragma unroll - for (int i = 0; i < 8; ++i) { - x_scaled[i] = __fmul_rn(x[i], block_scale_inverse); - } - asm volatile( "{\n" ".reg .b8 byte0, byte1, byte2, byte3;\n" @@ -195,84 +213,28 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( "}" : "=r"(out), "=r"(out_dequant_1), "=r"(out_dequant_2), "=r"(out_dequant_3), "=r"(out_dequant_4) - : "f"(x_scaled[0]), "f"(x_scaled[1]), "f"(x_scaled[2]), "f"(x_scaled[3]), "f"(x_scaled[4]), - "f"(x_scaled[5]), "f"(x_scaled[6]), "f"(x_scaled[7])); - - const uint16_t out_dequant_1_hi = (out_dequant_1 >> 16) & 0xFFFF; - const uint16_t out_dequant_1_lo = out_dequant_1 & 0xFFFF; - const uint16_t out_dequant_2_hi = (out_dequant_2 >> 16) & 0xFFFF; - const uint16_t out_dequant_2_lo = out_dequant_2 & 0xFFFF; - const uint16_t out_dequant_3_hi = (out_dequant_3 >> 16) & 0xFFFF; - const uint16_t out_dequant_3_lo = out_dequant_3 & 0xFFFF; - const uint16_t out_dequant_4_hi = (out_dequant_4 >> 16) & 0xFFFF; - const uint16_t out_dequant_4_lo = out_dequant_4 & 0xFFFF; - - constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f - static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); - constexpr float fp8_4over6_max = static_cast(E4M3_MAX); - constexpr float err_denom = fp4_max * fp8_4over6_max; + : "f"(__fmul_rn(x[0], block_scale_inverse)), "f"(__fmul_rn(x[1], block_scale_inverse)), + "f"(__fmul_rn(x[2], block_scale_inverse)), "f"(__fmul_rn(x[3], block_scale_inverse)), + "f"(__fmul_rn(x[4], block_scale_inverse)), "f"(__fmul_rn(x[5], block_scale_inverse)), + "f"(__fmul_rn(x[6], block_scale_inverse)), "f"(__fmul_rn(x[7], block_scale_inverse))); + const float sf = static_cast(S_dec_b_fp8); - if constexpr (FourOverSixConfig::err_use_fast_math) { - const float dequant[8] = { - __half2float(__ushort_as_half(out_dequant_1_lo)), - __half2float(__ushort_as_half(out_dequant_1_hi)), - __half2float(__ushort_as_half(out_dequant_2_lo)), - __half2float(__ushort_as_half(out_dequant_2_hi)), - __half2float(__ushort_as_half(out_dequant_3_lo)), - __half2float(__ushort_as_half(out_dequant_3_hi)), - __half2float(__ushort_as_half(out_dequant_4_lo)), - __half2float(__ushort_as_half(out_dequant_4_hi)), - }; -#pragma unroll - for (int i = 0; i < 8; ++i) { - const float val = dequant[i] * sf * global_amax / err_denom; - const float diff = val - x[i]; - *err += compute_4over6_error(diff); - } - } else { - const float val0 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_lo)), sf), global_amax), - err_denom); - const float val1 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_1_hi)), sf), global_amax), - err_denom); - const float val2 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_lo)), sf), global_amax), - err_denom); - const float val3 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_2_hi)), sf), global_amax), - err_denom); - const float val4 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_lo)), sf), global_amax), - err_denom); - const float val5 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_3_hi)), sf), global_amax), - err_denom); - const float val6 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_lo)), sf), global_amax), - err_denom); - const float val7 = __fdiv_rn( - __fmul_rn(__fmul_rn(__half2float(__ushort_as_half(out_dequant_4_hi)), sf), global_amax), - err_denom); - - const float diff0 = __fsub_rn(val0, x[0]); - const float diff1 = __fsub_rn(val1, x[1]); - const float diff2 = __fsub_rn(val2, x[2]); - const float diff3 = __fsub_rn(val3, x[3]); - const float diff4 = __fsub_rn(val4, x[4]); - const float diff5 = __fsub_rn(val5, x[5]); - const float diff6 = __fsub_rn(val6, x[6]); - const float diff7 = __fsub_rn(val7, x[7]); - - *err = __fadd_rn(*err, compute_4over6_error_rn(diff0)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff1)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff2)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff3)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff4)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff5)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff6)); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff7)); - } + accumulate_4over6_dequant_error(out_dequant_1, x[0], sf, + global_amax, err); + accumulate_4over6_dequant_error(out_dequant_1, x[1], sf, + global_amax, err); + accumulate_4over6_dequant_error(out_dequant_2, x[2], sf, + global_amax, err); + accumulate_4over6_dequant_error(out_dequant_2, x[3], sf, + global_amax, err); + accumulate_4over6_dequant_error(out_dequant_3, x[4], sf, + global_amax, err); + accumulate_4over6_dequant_error(out_dequant_3, x[5], sf, + global_amax, err); + accumulate_4over6_dequant_error(out_dequant_4, x[6], sf, + global_amax, err); + accumulate_4over6_dequant_error(out_dequant_4, x[7], sf, + global_amax, err); } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " @@ -442,10 +404,10 @@ __device__ __forceinline__ void load_4over6_vec_index_halves_16x(const vec_type } } -template +template __device__ __forceinline__ void quantize_4over6_candidates_16x( - const float (&x)[16], const QuantizationScales4Over6 &scaling_factors, const float global_amax, - QuantizationCandidates4Over6 &candidates) { + const input_type (&x)[16], const QuantizationScales4Over6 &scaling_factors, + const float global_amax, QuantizationCandidates4Over6 &candidates) { float first_half[8]; float second_half[8]; load_4over6_contiguous_halves_16x(x, first_half, second_half); @@ -490,9 +452,9 @@ __device__ __forceinline__ bool record_and_select_4over6_2d_block( } template + size_t BLOCKS_PER_TILE_X, typename input_type> __device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( - const float (&x)[16], const float block_amax, const float global_encode_scale, + const input_type (&x)[16], const float block_amax, const float global_encode_scale, const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, const size_t block_in_tile_x, const size_t participant_idx, QuantizationScratch4Over6 &scratch, diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index c7448d1849..76797d86ad 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1099,25 +1099,25 @@ __global__ void __launch_bounds__(THREADS_NUM) } if constexpr (FourOverSixConfig::enabled) { - float x_4over6[SCALE_DIM]; -#pragma unroll - for (int i = 0; i < SCALE_DIM; ++i) { - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - x_4over6[i] = static_cast(in_colwise_IType[i]); - } else { - x_4over6[i] = in_compute_colwise[i]; - } - } - const size_t block_col = threadIdx.x % BLOCK_DIM; QuantizationCandidates4Over6 candidates; nvfp4_scale_t S_dec_b_fp8; - const bool pick_map4 = - quantize_and_select_4over6_2d_block_16x( - x_4over6, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, - block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, S_dec_b_fp8, - candidates); + bool pick_map4; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + pick_map4 = + quantize_and_select_4over6_2d_block_16x( + in_colwise_IType, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, + block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, + S_dec_b_fp8, candidates); + } else { + pick_map4 = + quantize_and_select_4over6_2d_block_16x( + in_compute_colwise, block_amax, S_enc_colwise, S_dec_colwise, + global_amax_colwise, block_in_tile_y, block_in_tile_x, block_col, + *four_over_six_scratch, S_dec_b_fp8, candidates); + } const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; From 7deba7504adc988d95f989ba2765ed96c6f82096 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 20:01:57 -0700 Subject: [PATCH 52/57] Use standalone kernels Signed-off-by: Ziang Li --- .../common/cast/dispatch/quantize.cuh | 29 +- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 866 +++++++++--------- .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 352 +++---- .../quantize_transpose_nvfp4_tuned_1D.cuh | 300 ++---- .../common/transpose/cast_transpose.h | 2 - ...quantize_transpose_vector_blockwise_fp4.cu | 355 ++----- 6 files changed, 757 insertions(+), 1147 deletions(-) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index ea016260d7..f25a64053f 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,6 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" +#include "../nvfp4/quantize_4over6_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -122,7 +123,15 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { + if (use_4over6) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_4over6( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_4over6( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -145,10 +154,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/row_scaled_nvfp4, - /*use_4over6=*/use_4over6, - /*nvfp4_e4m3_max=*/quant_config_cpp.nvfp4_e4m3_max, - /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, - /*nvfp4_4over6_err_use_fast_math=*/quant_config_cpp.nvfp4_4over6_err_use_fast_math, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } @@ -279,7 +284,15 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { + if (use_4over6) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_4over6( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_4over6( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -302,10 +315,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, /*row_scaled_nvfp4=*/false, - /*use_4over6=*/use_4over6, - /*nvfp4_e4m3_max=*/quant_config_cpp.nvfp4_e4m3_max, - /*nvfp4_4over6_err_mode=*/quant_config_cpp.nvfp4_4over6_err_mode, - /*nvfp4_4over6_err_use_fast_math=*/quant_config_cpp.nvfp4_4over6_err_use_fast_math, /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index a1d9b52d28..d8fa24c1ea 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -5,13 +5,16 @@ ************************************************************************/ /*! \file quantize_4over6_nvfp4.cuh - * \brief Helpers used by NVFP4 4over6 quantization. + * \brief Dedicated kernels for NVFP4 4over6 quantization. * - * 4over6 evaluates two TE-style NVFP4 encodings for each 1x16 block. The - * map-to-6 candidate uses the normal block scale. The map-to-4 candidate uses - * a 1.5x expanded block scale, which maps the FP4 value 4 to the same dynamic - * range as FP4 value 6. The selected candidate is the one with lower configured - * error after dequantizing back to the original input domain; ties select map-to-6. + * Four Over Six evaluates two TE-style NVFP4 encodings for every 1x16 + * quantization group. The map-to-6 candidate uses the normal scale. The + * map-to-4 candidate expands the E4M3 block scale by 1.5x so FP4 value 4 + * reaches the same range that FP4 value 6 reaches in the normal encoding. + * The selected candidate is the one with lower configured dequantization + * error; ties select map-to-6. The quantized candidates, dequantized values, + * and errors are kept in registers, matching the structure of the official + * Four Over Six implementation. */ #ifndef TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ @@ -21,13 +24,16 @@ #include #include #include +#include +#include "../../common.h" +#include "../../util/math.h" +#include "../../utils.cuh" #include "core_nvfp4.cuh" namespace transformer_engine { namespace dispatch { namespace nvfp4 { -namespace core { #if FP4_TYPE_SUPPORTED @@ -56,91 +62,46 @@ namespace core { { __VA_ARGS__ } \ } -template -struct NVFP44Over6Config { - static constexpr bool enabled = kEnabled; +namespace quantize_4over6_kernel { + +constexpr int kThreads = 128; +constexpr int kWarpThreads = 32; +constexpr int kGroupSize = 16; +constexpr int kTileRows = 128; +constexpr int kTileCols = 64; +constexpr int kTileColGroups = kTileCols / kGroupSize; +constexpr int kTileRowGroups = kTileRows / kGroupSize; +constexpr int kElementsPerHalfGroup = 8; +constexpr int kPackedWordsPerGroup = 2; + +template +struct Config { static constexpr NVTENVFP44Over6ErrMode err_mode = kErrMode; static constexpr bool err_use_fast_math = kErrUseFastMath; }; -using NVFP44Over6DisabledConfig = NVFP44Over6Config; - -#define TRANSFORMER_ENGINE_NVFP4_4OVER6_CONFIG_SWITCH(USE_4OVER6_VALUE, ERR_MODE_VALUE, \ - ERR_USE_FAST_MATH_VALUE, CONFIG_CONST, ...) \ - if (USE_4OVER6_VALUE) { \ - TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( \ - ERR_MODE_VALUE, ERR_MODE_CONST, \ - TRANSFORMER_ENGINE_SWITCH_CONDITION(ERR_USE_FAST_MATH_VALUE, ERR_USE_FAST_MATH_CONST, { \ - using CONFIG_CONST = NVFP44Over6Config; \ - { __VA_ARGS__ } \ - });); \ - } else { \ - using CONFIG_CONST = NVFP44Over6DisabledConfig; \ - { __VA_ARGS__ } \ - } - -__device__ __forceinline__ void compute_4over6_decoding_scaling_factors( - const float block_amax, const float S_enc, nvfp4_scale_t &S_dec_b_fp8_map4, - nvfp4_scale_t &S_dec_b_fp8_map6) { - constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f - constexpr float fp8_max = detail::TypeExtrema::max; // 448.0f - constexpr float scale_expansion_factor = 1.5f; - const float base_sf_high_precision = block_amax / fp4_max * S_enc; - const float sf_high_precision_map4 = - fminf(base_sf_high_precision * scale_expansion_factor, fp8_max); - const float sf_high_precision_map6 = fminf(base_sf_high_precision, fp8_max); - S_dec_b_fp8_map4 = static_cast(sf_high_precision_map4); - S_dec_b_fp8_map6 = static_cast(sf_high_precision_map6); -} - -struct QuantizationScales4Over6 { - nvfp4_scale_t S_dec_b_fp8_map4; - nvfp4_scale_t S_dec_b_fp8_map6; - float SFcoefficient_map4; - float SFcoefficient_map6; +struct Candidate { + uint32_t packed[kPackedWordsPerGroup]; + float err; }; -__device__ __forceinline__ float compute_4over6_nvfp4_scaling_coefficient( - const nvfp4_scale_t S_dec_block, const float S_enc) { - const float S_dec = 1.0f / S_enc; - return fminf(1.0f / (static_cast(S_dec_block) * S_dec), detail::TypeExtrema::max); -} - -__device__ __forceinline__ QuantizationScales4Over6 -compute_4over6_nvfp4_quantization_scaling_factors(const float block_amax, const float S_enc) { - QuantizationScales4Over6 scaling_factors; - compute_4over6_decoding_scaling_factors(block_amax, S_enc, scaling_factors.S_dec_b_fp8_map4, - scaling_factors.S_dec_b_fp8_map6); - scaling_factors.SFcoefficient_map4 = - compute_4over6_nvfp4_scaling_coefficient(scaling_factors.S_dec_b_fp8_map4, S_enc); - scaling_factors.SFcoefficient_map6 = - compute_4over6_nvfp4_scaling_coefficient(scaling_factors.S_dec_b_fp8_map6, S_enc); - return scaling_factors; -} +struct CandidatePair { + Candidate map4; + Candidate map6; +}; -__device__ __forceinline__ QuantizationScales4Over6 -compute_4over6_fp4_encode_quantization_scaling_factors(const float block_amax, - const float global_encode_scale, - const float global_decode_scale) { - QuantizationScales4Over6 scaling_factors; - compute_4over6_decoding_scaling_factors(block_amax, global_encode_scale, - scaling_factors.S_dec_b_fp8_map4, - scaling_factors.S_dec_b_fp8_map6); - scaling_factors.SFcoefficient_map4 = - fminf(1.0f / (static_cast(scaling_factors.S_dec_b_fp8_map4) * global_decode_scale), - detail::TypeExtrema::max); - scaling_factors.SFcoefficient_map6 = - fminf(1.0f / (static_cast(scaling_factors.S_dec_b_fp8_map6) * global_decode_scale), - detail::TypeExtrema::max); - return scaling_factors; -} +struct ScalePair { + nvfp4_scale_t map4; + nvfp4_scale_t map6; + float inv_map4; + float inv_map6; +}; -template -__device__ __forceinline__ float compute_4over6_error_rn(const float diff) { - if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMSE) { +template +__device__ __forceinline__ float compute_error_rn(const float diff) { + if constexpr (kErrMode == kNVTENVFP44Over6ErrMSE) { return __fmul_rn(diff, diff); - } else if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMAE) { + } else if constexpr (kErrMode == kNVTENVFP44Over6ErrMAE) { return fabsf(diff); } else { NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 error mode."); @@ -148,11 +109,11 @@ __device__ __forceinline__ float compute_4over6_error_rn(const float diff) { } } -template -__device__ __forceinline__ float compute_4over6_error(const float diff) { - if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMSE) { +template +__device__ __forceinline__ float compute_error(const float diff) { + if constexpr (kErrMode == kNVTENVFP44Over6ErrMSE) { return diff * diff; - } else if constexpr (ERR_MODE == kNVTENVFP44Over6ErrMAE) { + } else if constexpr (kErrMode == kNVTENVFP44Over6ErrMAE) { return fabsf(diff); } else { NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 error mode."); @@ -160,36 +121,93 @@ __device__ __forceinline__ float compute_4over6_error(const float diff) { } } -template -__device__ __forceinline__ void accumulate_4over6_dequant_error(const uint32_t dequant_bits, - const float x, const float sf, - const float global_amax, - float *err) { - constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f +template +__device__ __forceinline__ ScalePair compute_scale_pair(const float block_amax, + const float global_amax) { static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); - constexpr float fp8_4over6_max = static_cast(E4M3_MAX); - constexpr float err_denom = fp4_max * fp8_4over6_max; + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_max = detail::TypeExtrema::max; // 448.0f + constexpr float expand_to_map4 = 1.5f; + const float S_enc = core::compute_global_encode_scaling_factor_FP4(global_amax); + const float base = block_amax / fp4_max * S_enc; + + ScalePair scales; + scales.map4 = static_cast(fminf(base * expand_to_map4, fp8_max)); + scales.map6 = static_cast(fminf(base, fp8_max)); + + const float S_dec = 1.0f / S_enc; + scales.inv_map4 = + fminf(1.0f / (static_cast(scales.map4) * S_dec), detail::TypeExtrema::max); + scales.inv_map6 = + fminf(1.0f / (static_cast(scales.map6) * S_dec), detail::TypeExtrema::max); + return scales; +} + +template +__device__ __forceinline__ float load_input(const IType *ptr, const size_t idx) { + return static_cast(ptr[idx]); +} + +template +__device__ __forceinline__ void load_row_group(const IType *tile, const int row, + const int col_start, float (&x0)[8], float (&x1)[8], + float *amax) { + *amax = 0.0f; +#pragma unroll + for (int i = 0; i < kElementsPerHalfGroup; ++i) { + const float v0 = load_input(tile, row * kTileCols + col_start + i); + const float v1 = load_input(tile, row * kTileCols + col_start + i + kElementsPerHalfGroup); + x0[i] = v0; + x1[i] = v1; + *amax = fmaxf(*amax, fabsf(v0)); + *amax = fmaxf(*amax, fabsf(v1)); + } +} + +template +__device__ __forceinline__ void load_col_group(const IType *tile, const int row_start, + const int col, float (&x0)[8], float (&x1)[8], + float *amax) { + *amax = 0.0f; +#pragma unroll + for (int i = 0; i < kElementsPerHalfGroup; ++i) { + const float v0 = load_input(tile, (row_start + i) * kTileCols + col); + const float v1 = load_input(tile, (row_start + i + kElementsPerHalfGroup) * kTileCols + col); + x0[i] = v0; + x1[i] = v1; + *amax = fmaxf(*amax, fabsf(v0)); + *amax = fmaxf(*amax, fabsf(v1)); + } +} + +template +__device__ __forceinline__ void accumulate_dequant_error(const uint32_t dequant_bits, const float x, + const float sf, const float global_amax, + float *err) { + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_max = static_cast(E4M3_MAX); + constexpr float err_denom = fp4_max * fp8_max; const uint16_t half_bits = (dequant_bits >> SHIFT) & 0xFFFF; - if constexpr (FourOverSixConfig::err_use_fast_math) { - const float val = __half2float(__ushort_as_half(half_bits)) * sf * global_amax / err_denom; + if constexpr (Cfg::err_use_fast_math) { + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = dequant * sf * global_amax / err_denom; const float diff = val - x; - *err += compute_4over6_error(diff); + *err += compute_error(diff); } else { - const float val = - __fdiv_rn(__fmul_rn(__fmul_rn(__half2float(__ushort_as_half(half_bits)), sf), global_amax), - err_denom); + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = __fdiv_rn(__fmul_rn(__fmul_rn(dequant, sf), global_amax), err_denom); const float diff = __fsub_rn(val, x); - *err = __fadd_rn(*err, compute_4over6_error_rn(diff)); + *err = __fadd_rn(*err, compute_error_rn(diff)); } } -template -__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( - const float (&x)[8], const float block_scale_inverse, const nvfp4_scale_t S_dec_b_fp8, - const float global_amax, float *err) { - static_assert(FourOverSixConfig::enabled, - "4over6 conversion helpers require an enabled 4over6 config."); +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8], + const float block_scale_inverse, + const nvfp4_scale_t sf, + const float global_amax, + float *err) { uint32_t out = 0; uint32_t out_dequant_1 = 0; uint32_t out_dequant_2 = 0; @@ -217,393 +235,367 @@ __device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error_rn( "f"(__fmul_rn(x[2], block_scale_inverse)), "f"(__fmul_rn(x[3], block_scale_inverse)), "f"(__fmul_rn(x[4], block_scale_inverse)), "f"(__fmul_rn(x[5], block_scale_inverse)), "f"(__fmul_rn(x[6], block_scale_inverse)), "f"(__fmul_rn(x[7], block_scale_inverse))); - - const float sf = static_cast(S_dec_b_fp8); - accumulate_4over6_dequant_error(out_dequant_1, x[0], sf, - global_amax, err); - accumulate_4over6_dequant_error(out_dequant_1, x[1], sf, - global_amax, err); - accumulate_4over6_dequant_error(out_dequant_2, x[2], sf, - global_amax, err); - accumulate_4over6_dequant_error(out_dequant_2, x[3], sf, - global_amax, err); - accumulate_4over6_dequant_error(out_dequant_3, x[4], sf, - global_amax, err); - accumulate_4over6_dequant_error(out_dequant_3, x[5], sf, - global_amax, err); - accumulate_4over6_dequant_error(out_dequant_4, x[6], sf, - global_amax, err); - accumulate_4over6_dequant_error(out_dequant_4, x[7], sf, - global_amax, err); } else { NVTE_DEVICE_ERROR( "FP4 cvt PTX instructions are architecture-specific. " "Try recompiling with sm_XXXa instead of sm_XXX."); } + const float sf_float = static_cast(sf); + accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_1, x[1], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[2], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[3], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[4], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[5], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[6], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[7], sf_float, global_amax, err); return out; } -template -__device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], - const float (&second_half)[8], - const QuantizationScales4Over6 &scaling_factors, - const float global_amax, float &err_map4, - float &err_map6, uint32_t (&rOut_map4)[2], - uint32_t (&rOut_map6)[2]) { - if constexpr (REVERSE_PACK_ORDER) { - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( - second_half, static_cast(scaling_factors.SFcoefficient_map4), - scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( - second_half, static_cast(scaling_factors.SFcoefficient_map6), - scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( - first_half, static_cast(scaling_factors.SFcoefficient_map4), - scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( - first_half, static_cast(scaling_factors.SFcoefficient_map6), - scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - } else { - rOut_map4[0] = cvt_fp32_to_fp4_8x_with_error_rn( - first_half, static_cast(scaling_factors.SFcoefficient_map4), - scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[0] = cvt_fp32_to_fp4_8x_with_error_rn( - first_half, static_cast(scaling_factors.SFcoefficient_map6), - scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); - rOut_map4[1] = cvt_fp32_to_fp4_8x_with_error_rn( - second_half, static_cast(scaling_factors.SFcoefficient_map4), - scaling_factors.S_dec_b_fp8_map4, global_amax, &err_map4); - rOut_map6[1] = cvt_fp32_to_fp4_8x_with_error_rn( - second_half, static_cast(scaling_factors.SFcoefficient_map6), - scaling_factors.S_dec_b_fp8_map6, global_amax, &err_map6); +template +__device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], const float (&x1)[8], + const ScalePair &scales, + const float global_amax) { + CandidatePair candidates; + candidates.map4.err = 0.0f; + candidates.map6.err = 0.0f; + candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error( + x0, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( + x0, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( + x1, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( + x1, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + return candidates; +} + +__device__ __forceinline__ float reduce_group_sum_16(float value) { + const int lane = threadIdx.x & (kWarpThreads - 1); + const int group_base = lane & ~(kGroupSize - 1); + const unsigned mask = 0xffffu << group_base; +#pragma unroll + for (int offset = kGroupSize / 2; offset > 0; offset /= 2) { + value += __shfl_down_sync(mask, value, offset, kGroupSize); + } + return __shfl_sync(mask, value, group_base, kWarpThreads); +} + +__device__ __forceinline__ float reduce_group_max_16(float value) { + const int lane = threadIdx.x & (kWarpThreads - 1); + const int group_base = lane & ~(kGroupSize - 1); + const unsigned mask = 0xffffu << group_base; +#pragma unroll + for (int offset = kGroupSize / 2; offset > 0; offset /= 2) { + value = fmaxf(value, __shfl_down_sync(mask, value, offset, kGroupSize)); } + return __shfl_sync(mask, value, group_base, kWarpThreads); } -__device__ __forceinline__ bool pick_4over6_map4(const float err_map4, const float err_map6) { - return err_map4 < err_map6; +__device__ __forceinline__ void store_packed_group(const uint32_t *packed, fp4e2m1x2 *dst) { + auto *dst32 = reinterpret_cast(dst); + dst32[0] = packed[0]; + dst32[1] = packed[1]; } -__device__ __forceinline__ nvfp4_scale_t -selected_4over6_scale(const bool pick_map4, const QuantizationScales4Over6 &scaling_factors) { +__device__ __forceinline__ const uint32_t *select_packed(const CandidatePair &candidates, + const bool pick_map4) { if (pick_map4) { - return scaling_factors.S_dec_b_fp8_map4; + return candidates.map4.packed; } - return scaling_factors.S_dec_b_fp8_map6; + return candidates.map6.packed; } -template -__device__ __forceinline__ void quantize_4over6_16x(const float (&first_half)[8], - const float (&second_half)[8], - const QuantizationScales4Over6 &scaling_factors, - const float global_amax, - nvfp4_scale_t &S_dec_b_fp8, - uint32_t (&rOut)[2]) { - float err_map4 = 0.0f; - float err_map6 = 0.0f; - __align__(8) uint32_t rOut_map4[2]; - __align__(8) uint32_t rOut_map6[2]; - - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, err_map4, err_map6, rOut_map4, - rOut_map6); - - const bool pick_map4 = pick_4over6_map4(err_map4, err_map6); - S_dec_b_fp8 = selected_4over6_scale(pick_map4, scaling_factors); +__device__ __forceinline__ nvfp4_scale_t select_scale(const ScalePair &scales, + const bool pick_map4) { if (pick_map4) { - rOut[0] = rOut_map4[0]; - rOut[1] = rOut_map4[1]; - } else { - rOut[0] = rOut_map6[0]; - rOut[1] = rOut_map6[1]; + return scales.map4; } + return scales.map6; } -struct QuantizationCandidates4Over6 { - float err_map4; - float err_map6; - uint32_t rOut_map4[2]; - uint32_t rOut_map6[2]; - - __device__ __forceinline__ void reset_errors() { - err_map4 = 0.0f; - err_map6 = 0.0f; +template +__device__ void load_tile_to_shared(const IType *input, IType *tile, const size_t rows, + const size_t cols, const size_t tile_row, + const size_t tile_col) { + constexpr int vec_elems = 16 / sizeof(IType); + constexpr int vecs_per_row = kTileCols / vec_elems; + constexpr int vecs = kTileRows * vecs_per_row; + using TileVec = Vec; + + for (int idx = threadIdx.x; idx < vecs; idx += blockDim.x) { + const int local_row = idx / vecs_per_row; + const int local_vec_col = idx - local_row * vecs_per_row; + const int local_col = local_vec_col * vec_elems; + const size_t global_row = tile_row + local_row; + const size_t global_col = tile_col + local_col; + + TileVec vec; + if (global_row < rows && global_col + vec_elems <= cols) { + vec.load_from(&input[global_row * cols + global_col]); + } else { + vec.clear(); +#pragma unroll + for (int i = 0; i < vec_elems; ++i) { + if (global_row < rows && global_col + i < cols) { + vec.data.elt[i] = input[global_row * cols + global_col + i]; + } + } + } + vec.store_to(&tile[local_row * kTileCols + local_col]); } +} - __device__ __forceinline__ const uint32_t *selected_packed(const bool pick_map4) const { - if (pick_map4) { - return rOut_map4; +template +__device__ void quantize_tile_rowwise(const IType *tile, fp4e2m1x2 *output, nvfp4_scale_t *scales, + const float *amax, const size_t rows, const size_t cols, + const size_t tile_row, const size_t tile_col, + const size_t scale_stride) { + constexpr int groups = kTileRows * kTileColGroups; + for (int group = threadIdx.x; group < groups; group += blockDim.x) { + const int local_row = group % kTileRows; + const int local_col_group = group / kTileRows; + const int local_col = local_col_group * kGroupSize; + const size_t global_row = tile_row + local_row; + const size_t global_col = tile_col + local_col; + if (global_row >= rows || global_col >= cols) { + continue; } - return rOut_map6; - } -}; -template -struct alignas(16) QuantizationScratch4Over6 { - alignas(16) float err_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; - alignas(16) float err_map6_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X][BLOCK_DIM]; - alignas(16) uint8_t pick_map4_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; - alignas(16) nvfp4_scale_t selected_scale_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X]; - - template - static constexpr size_t dynamic_shared_memory_size() { - if constexpr (USE_2D_QUANTIZATION && USE_4OVER6) { - return ((sizeof(QuantizationScratch4Over6) + TMA_SHMEM_ALIGNMENT - 1) / TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; + float x0[8]; + float x1[8]; + float group_amax = 0.0f; + load_row_group(tile, local_row, local_col, x0, x1, &group_amax); + + float block_amax = group_amax; + if constexpr (USE_2D_QUANTIZATION) { + block_amax = reduce_group_max_16(group_amax); } - return 0; - } -}; -template -__device__ __forceinline__ void load_4over6_contiguous_halves_16x(const input_type *x, - float (&first_half)[8], - float (&second_half)[8]) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - first_half[i] = static_cast(x[i]); - second_half[i] = static_cast(x[i + 8]); - } -} + float global_amax = amax[0]; + if constexpr (ROW_SCALED_NVFP4) { + global_amax = amax[global_row]; + } -template -__device__ __forceinline__ void load_4over6_pair_array_halves_16x(const pair_type (&x)[2][4], - float (&first_half)[8], - float (&second_half)[8]) { -#pragma unroll - for (int i = 0; i < 4; ++i) { - first_half[2 * i] = static_cast(x[0][i].x); - first_half[2 * i + 1] = static_cast(x[0][i].y); - second_half[2 * i] = static_cast(x[1][i].x); - second_half[2 * i + 1] = static_cast(x[1][i].y); - } -} + const ScalePair scale_pair = compute_scale_pair(block_amax, global_amax); + CandidatePair candidates = make_candidates(x0, x1, scale_pair, global_amax); -template -__device__ __forceinline__ void load_4over6_vec2_array_halves_16x(const vec_type (&x)[8], - float (&first_half)[8], - float (&second_half)[8]) { -#pragma unroll - for (int i = 0; i < 4; ++i) { - first_half[2 * i] = static_cast(x[i].data.elt[0]); - first_half[2 * i + 1] = static_cast(x[i].data.elt[1]); - second_half[2 * i] = static_cast(x[i + 4].data.elt[0]); - second_half[2 * i + 1] = static_cast(x[i + 4].data.elt[1]); - } -} + float err_map4 = candidates.map4.err; + float err_map6 = candidates.map6.err; + if constexpr (USE_2D_QUANTIZATION) { + err_map4 = reduce_group_sum_16(err_map4); + err_map6 = reduce_group_sum_16(err_map6); + } -template -__device__ __forceinline__ void load_4over6_vec_index_halves_16x(const vec_type (&x)[16], - const int idx, - float (&first_half)[8], - float (&second_half)[8]) { -#pragma unroll - for (int i = 0; i < 8; ++i) { - first_half[i] = static_cast(x[i].data.elt[idx]); - second_half[i] = static_cast(x[i + 8].data.elt[idx]); + const bool pick_map4 = err_map4 < err_map6; + const nvfp4_scale_t selected_scale = select_scale(scale_pair, pick_map4); + const uint32_t *selected = select_packed(candidates, pick_map4); + + const size_t global_col_group = global_col / kGroupSize; + scales[global_row * scale_stride + global_col_group] = selected_scale; + store_packed_group(selected, &output[(global_row * cols + global_col) / 2]); } } -template -__device__ __forceinline__ void quantize_4over6_candidates_16x( - const input_type (&x)[16], const QuantizationScales4Over6 &scaling_factors, - const float global_amax, QuantizationCandidates4Over6 &candidates) { - float first_half[8]; - float second_half[8]; - load_4over6_contiguous_halves_16x(x, first_half, second_half); - - candidates.reset_errors(); - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, candidates.err_map4, - candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); -} +template +__device__ void quantize_tile_colwise(const IType *tile, fp4e2m1x2 *output_t, + nvfp4_scale_t *scales_t, const float *amax, const size_t rows, + const size_t cols, const size_t tile_row, + const size_t tile_col, const size_t scale_stride_t) { + constexpr int groups = kTileRowGroups * kTileCols; + for (int group = threadIdx.x; group < groups; group += blockDim.x) { + const int local_row_group = group / kTileCols; + const int local_col = group - local_row_group * kTileCols; + const int local_row = local_row_group * kGroupSize; + const size_t global_row = tile_row + local_row; + const size_t global_col = tile_col + local_col; + if (global_row >= rows || global_col >= cols) { + continue; + } -template -__device__ __forceinline__ bool record_and_select_4over6_2d_block( - const QuantizationScales4Over6 &scaling_factors, const size_t block_in_tile_y, - const size_t block_in_tile_x, const size_t participant_idx, - QuantizationScratch4Over6 &scratch, - nvfp4_scale_t &S_dec_b_fp8, const QuantizationCandidates4Over6 &candidates) { - scratch.err_map4_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map4; - scratch.err_map6_matrix[block_in_tile_y][block_in_tile_x][participant_idx] = candidates.err_map6; - __syncthreads(); + float x0[8]; + float x1[8]; + float group_amax = 0.0f; + load_col_group(tile, local_row, local_col, x0, x1, &group_amax); - if (participant_idx == 0) { - float block_err_map4 = 0.0f; - float block_err_map6 = 0.0f; -#pragma unroll - for (int i = 0; i < BLOCK_DIM; ++i) { - block_err_map4 += scratch.err_map4_matrix[block_in_tile_y][block_in_tile_x][i]; - block_err_map6 += scratch.err_map6_matrix[block_in_tile_y][block_in_tile_x][i]; + float block_amax = group_amax; + if constexpr (USE_2D_QUANTIZATION) { + block_amax = reduce_group_max_16(group_amax); } - const bool pick_map4 = pick_4over6_map4(block_err_map4, block_err_map6); - if (pick_map4) { - scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 1; - } else { - scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] = 0; + const float global_amax = amax[0]; + const ScalePair scale_pair = compute_scale_pair(block_amax, global_amax); + CandidatePair candidates = make_candidates(x0, x1, scale_pair, global_amax); + + float err_map4 = candidates.map4.err; + float err_map6 = candidates.map6.err; + if constexpr (USE_2D_QUANTIZATION) { + err_map4 = reduce_group_sum_16(err_map4); + err_map6 = reduce_group_sum_16(err_map6); } - scratch.selected_scale_matrix[block_in_tile_y][block_in_tile_x] = - selected_4over6_scale(pick_map4, scaling_factors); - } - __syncthreads(); - S_dec_b_fp8 = scratch.selected_scale_matrix[block_in_tile_y][block_in_tile_x]; - return scratch.pick_map4_matrix[block_in_tile_y][block_in_tile_x] != 0; -} -template -__device__ __forceinline__ bool quantize_and_select_4over6_2d_block_16x( - const input_type (&x)[16], const float block_amax, const float global_encode_scale, - const float global_decode_scale, const float global_amax, const size_t block_in_tile_y, - const size_t block_in_tile_x, const size_t participant_idx, - QuantizationScratch4Over6 &scratch, - nvfp4_scale_t &S_dec_b_fp8, QuantizationCandidates4Over6 &candidates) { - const auto scaling_factors = compute_4over6_fp4_encode_quantization_scaling_factors( - block_amax, global_encode_scale, global_decode_scale); - quantize_4over6_candidates_16x(x, scaling_factors, global_amax, - candidates); - - const bool pick_map4 = - record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, scratch, S_dec_b_fp8, - candidates); - return pick_map4; -} + const bool pick_map4 = err_map4 < err_map6; + const nvfp4_scale_t selected_scale = select_scale(scale_pair, pick_map4); + const uint32_t *selected = select_packed(candidates, pick_map4); -template -__device__ __forceinline__ void store_4over6_colwise_packed_16x( - const bool pick_map4, const QuantizationCandidates4Over6 &candidates, const int thread_lane, - output_type *out_t_data_sh, const size_t shmem_offset_base_colwise_out_t) { - const uint32_t *regs_4x = candidates.selected_packed(pick_map4); - const int group = thread_lane / 16; - uint32_t val[2]; - switch (group) { - case 0: - val[0] = regs_4x[0]; - val[1] = regs_4x[1]; - break; - case 1: - val[0] = regs_4x[1]; - val[1] = regs_4x[0]; - break; + const size_t global_row_group = global_row / kGroupSize; + scales_t[global_col * scale_stride_t + global_row_group] = selected_scale; + store_packed_group(selected, &output_t[(global_col * rows + global_row) / 2]); } - uint32_t *out_t_data_sh_as_uint32_t = - reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); - out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; - out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } -template -__device__ __forceinline__ void store_4over6_rowwise_packed_16x( - const bool pick_map4, const QuantizationCandidates4Over6 &candidates, const int bank_group, - const size_t thread_offset_X_rowwise, const size_t shmem_offset_base_rowwise_out, - output_type *out_data_sh) { - const uint32_t *packed = candidates.selected_packed(pick_map4); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - uint32_t *out_data_sh_as_uint32_t = - reinterpret_cast(&out_data_sh[shmem_offset_rowwise]); - out_data_sh_as_uint32_t[0] = packed[swizzled_group_idx / PACK_SIZE]; +template +__global__ void __launch_bounds__(kThreads) + quantize_4over6_kernel(const IType *input, fp4e2m1x2 *output, fp4e2m1x2 *output_t, + nvfp4_scale_t *scales, nvfp4_scale_t *scales_t, + const float *amax_rowwise, const float *amax_colwise, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const float *noop) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; } -} - -template -__device__ __forceinline__ void store_4over6_packed_16x(const uint32_t *packed, - output_vec_type &output_vec) { - *reinterpret_cast(&output_vec.data.elt[0]) = packed[0]; - *reinterpret_cast(&output_vec.data.elt[4]) = packed[1]; -} - -template -__device__ __forceinline__ void store_selected_4over6_packed_16x( - const bool pick_map4, const QuantizationCandidates4Over6 &candidates, - output_vec_type &output_vec) { - store_4over6_packed_16x(candidates.selected_packed(pick_map4), output_vec); -} -template -__device__ __forceinline__ void quantize_4over6_contiguous_16x( - const input_type *x, const QuantizationScales4Over6 &scaling_factors, const float global_amax, - nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { - float first_half[8]; - float second_half[8]; - load_4over6_contiguous_halves_16x(x, first_half, second_half); + extern __shared__ char dynamic_shmem[]; + auto *tile = reinterpret_cast(dynamic_shmem); + const size_t tile_col = blockIdx.x * kTileCols; + const size_t tile_row = blockIdx.y * kTileRows; - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); -} - -template -__device__ __forceinline__ void quantize_4over6_pair_array_16x( - const pair_type (&x)[2][4], const QuantizationScales4Over6 &scaling_factors, - const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { - float first_half[8]; - float second_half[8]; - load_4over6_pair_array_halves_16x(x, first_half, second_half); + load_tile_to_shared(input, tile, rows, cols, tile_row, tile_col); + __syncthreads(); - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); -} + if constexpr (RETURN_IDENTITY) { + quantize_tile_rowwise( + tile, output, scales, amax_rowwise, rows, cols, tile_row, tile_col, scale_stride); + } -template -__device__ __forceinline__ void quantize_4over6_vec2_array_candidates_16x( - const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, - const float global_amax, QuantizationCandidates4Over6 &candidates) { - float first_half[8]; - float second_half[8]; - load_4over6_vec2_array_halves_16x(x, first_half, second_half); - - candidates.reset_errors(); - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, candidates.err_map4, - candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); + if constexpr (RETURN_TRANSPOSE) { + const float *columnwise_amax = amax_colwise; + if (columnwise_amax == nullptr) { + columnwise_amax = amax_rowwise; + } + quantize_tile_colwise( + tile, output_t, scales_t, columnwise_amax, rows, cols, tile_row, tile_col, scale_stride_t); + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif } -template -__device__ __forceinline__ void quantize_4over6_vec2_array_16x( - const vec_type (&x)[8], const QuantizationScales4Over6 &scaling_factors, - const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { - float first_half[8]; - float second_half[8]; - load_4over6_vec2_array_halves_16x(x, first_half, second_half); - - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); +template +void launch_quantize_4over6(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; + const bool return_identity = output->has_data(); + const bool return_transpose = output->has_columnwise_data(); + + const auto *input_ptr = reinterpret_cast(input.data.dptr); + auto *output_ptr = reinterpret_cast(output->data.dptr); + auto *output_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scales_ptr = reinterpret_cast(output->scale_inv.dptr); + auto *scales_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const auto *amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const auto *amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + + const dim3 grid(DIVUP(cols, static_cast(kTileCols)), + DIVUP(rows, static_cast(kTileRows))); + const dim3 block(kThreads); + const size_t shmem = kTileRows * kTileCols * sizeof(IType); + const size_t scale_stride = return_identity ? output->scale_inv.shape[1] : 0; + const size_t scale_stride_t = return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_identity, RETURN_IDENTITY, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { + auto kernel = quantize_4over6_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem); + kernel<<>>(input_ptr, output_ptr, output_t_ptr, scales_ptr, + scales_t_ptr, amax_rowwise_ptr, amax_colwise_ptr, + rows, cols, scale_stride, scale_stride_t, noop_ptr); + }); + }); + }); } -template -__device__ __forceinline__ void quantize_4over6_vec_index_candidates_16x( - const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, - const float global_amax, QuantizationCandidates4Over6 &candidates) { - float first_half[8]; - float second_half[8]; - load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); - - candidates.reset_errors(); - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, candidates.err_map4, - candidates.err_map6, candidates.rOut_map4, candidates.rOut_map6); -} +} // namespace quantize_4over6_kernel -template -__device__ __forceinline__ void quantize_4over6_vec_index_16x( - const vec_type (&x)[16], const int idx, const QuantizationScales4Over6 &scaling_factors, - const float global_amax, nvfp4_scale_t &S_dec_b_fp8, uint32_t (&rOut)[2]) { - float first_half[8]; - float second_half[8]; - load_4over6_vec_index_halves_16x(x, idx, first_half, second_half); +#endif // FP4_TYPE_SUPPORTED - quantize_4over6_16x( - first_half, second_half, scaling_factors, global_amax, S_dec_b_fp8, rOut); -} +template +void quantize_4over6(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_4over6_kernel; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(quant_config != nullptr && quant_config->nvfp4_4over6, + "NVFP4 4over6 quantization requires an enabled quantization config."); + NVTE_CHECK(!quant_config->stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + NVTE_CHECK(quant_config->nvfp4_e4m3_max == output->nvfp4_e4m3_max, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "NVFP4 4over6 output tensor must have rowwise or columnwise data."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); + NVTE_CHECK(input.flat_last_dim() % kGroupSize == 0, + "NVFP4 4over6 quantization requires columns divisible by ", kGroupSize, "."); + NVTE_CHECK(!(output->has_columnwise_data() || use_2d_quantization) || + input.flat_first_dim() % kGroupSize == 0, + "NVFP4 4over6 columnwise or 2D quantization requires rows divisible by ", kGroupSize, + "."); + NVTE_CHECK(!output->row_scaled_nvfp4 || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!output->row_scaled_nvfp4 || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); + NVTE_CHECK(!use_2d_quantization || output->has_data(), + "NVFP4 4over6 2D quantization requires rowwise output."); + + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise amax tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_amax.dptr != nullptr || output->amax.dptr != nullptr, + "NVFP4 4over6 columnwise quantization requires columnwise amax or rowwise amax."); + } + TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( + quant_config->nvfp4_e4m3_max, E4M3_MAX, + TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( + quant_config->nvfp4_4over6_err_mode, ERR_MODE, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config->nvfp4_4over6_err_use_fast_math, ERR_USE_FAST_MATH, { + using Cfg = quantize_4over6_kernel::Config; + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + quantize_4over6_kernel::launch_quantize_4over6( + input, noop, output, stream);); + }););); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED +} -} // namespace core } // namespace nvfp4 } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 76797d86ad..9e4aef5a1c 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -23,7 +23,6 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" -#include "quantize_4over6_nvfp4.cuh" #include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh" namespace transformer_engine { @@ -181,10 +180,6 @@ constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4; constexpr size_t TILE_DIM_Y = 32; constexpr size_t TILE_DIM_X = 128; -constexpr size_t NVFP4_2D_BLOCK_DIM = 16; -constexpr size_t NVFP4_2D_BLOCKS_PER_TILE_Y = TILE_DIM_Y / NVFP4_2D_BLOCK_DIM; -constexpr size_t NVFP4_2D_BLOCKS_PER_TILE_X = TILE_DIM_X / NVFP4_2D_BLOCK_DIM; - // SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 @@ -784,8 +779,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -818,9 +812,9 @@ __global__ void __launch_bounds__(THREADS_NUM) 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x // NEW: 2D Block-based scaling constants - constexpr size_t BLOCK_DIM = NVFP4_2D_BLOCK_DIM; - constexpr size_t BLOCKS_PER_TILE_Y = NVFP4_2D_BLOCKS_PER_TILE_Y; - constexpr size_t BLOCKS_PER_TILE_X = NVFP4_2D_BLOCKS_PER_TILE_X; + constexpr size_t BLOCK_DIM = 16; + constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 + constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 @@ -875,8 +869,6 @@ __global__ void __launch_bounds__(THREADS_NUM) constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; constexpr size_t out_mem_colwise_data = buff_size_aligned_out; constexpr size_t out_mem_rowwise_scales = 0; - constexpr size_t out_mem_colwise_scales = - (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM * sizeof(nvfp4_scale_t); extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); @@ -896,37 +888,21 @@ __global__ void __launch_bounds__(THREADS_NUM) dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - using FourOverSixScratch = - QuantizationScratch4Over6; - FourOverSixScratch *four_over_six_scratch = nullptr; - if constexpr (FourOverSixConfig::enabled) { - constexpr size_t four_over_six_scratch_offset = in_mem + out_mem_rowwise_data + - out_mem_colwise_data + out_mem_rowwise_scales + - out_mem_colwise_scales; - four_over_six_scratch = - reinterpret_cast(dshmem + four_over_six_scratch_offset); - } - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; const bool is_master_thread = (threadIdx.x == 0); // Compute a global encoding/decoding scaling factors for all S_dec_b - const float S_enc_rowwise = - (amax_rowwise_ptr == nullptr) - ? 1.0f - : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); // NOTE: This is to match with how emulation code was written. const float S_dec_rowwise = 1.0 / S_enc_rowwise; - const float S_enc_colwise = - (amax_colwise_ptr == nullptr) - ? S_enc_rowwise - : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); const float S_dec_colwise = 1.0 / S_enc_colwise; - const float global_amax_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f : *amax_rowwise_ptr; - const float global_amax_colwise = - (amax_colwise_ptr == nullptr) ? global_amax_rowwise : *amax_colwise_ptr; const size_t warp_id = threadIdx.x / 32; const size_t lane_id = threadIdx.x % 32; @@ -1098,85 +1074,56 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - if constexpr (FourOverSixConfig::enabled) { - const size_t block_col = threadIdx.x % BLOCK_DIM; - QuantizationCandidates4Over6 candidates; - nvfp4_scale_t S_dec_b_fp8; - bool pick_map4; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - pick_map4 = - quantize_and_select_4over6_2d_block_16x( - in_colwise_IType, block_amax, S_enc_colwise, S_dec_colwise, global_amax_colwise, - block_in_tile_y, block_in_tile_x, block_col, *four_over_six_scratch, - S_dec_b_fp8, candidates); - } else { - pick_map4 = - quantize_and_select_4over6_2d_block_16x( - in_compute_colwise, block_amax, S_enc_colwise, S_dec_colwise, - global_amax_colwise, block_in_tile_y, block_in_tile_x, block_col, - *four_over_six_scratch, S_dec_b_fp8, candidates); - } - - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - - store_4over6_colwise_packed_16x(pick_map4, candidates, thread_lane, out_t_data_sh, - shmem_offset_base_colwise_out_t); - } else { - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise); + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); - // // Store scaling factors through SHMEM - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; - fp4e2m1x4 regs[SCALE_DIM / 4]; + fp4e2m1x4 regs[SCALE_DIM / 4]; #pragma unroll - for (int e = 0; e < SCALE_DIM / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); - regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else { - const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); - const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); - regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); - } + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); } + } - const int group = thread_lane / 16; - uint32_t val[2]; - uint32_t *regs_4x = reinterpret_cast(regs); - - // Helps reducing bank conflicts - switch (group) { - case 0: - val[0] = regs_4x[0]; - val[1] = regs_4x[1]; - break; - case 1: - val[0] = regs_4x[1]; - val[1] = regs_4x[0]; - break; - } - uint32_t *out_t_data_sh_as_uint32_t = - reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); - out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; - out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } } @@ -1196,7 +1143,6 @@ __global__ void __launch_bounds__(THREADS_NUM) block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; float in_compute_rowwise[SCALE_DIM]; - float in_4over6_rowwise[SCALE_DIM]; Vec in_cached[WAVES]; // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY @@ -1212,15 +1158,6 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (FourOverSixConfig::enabled) { -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - in_4over6_rowwise[swizzled_group_idx + 2 * e] = - static_cast(in_IType[w].data.elt[e].x); - in_4over6_rowwise[swizzled_group_idx + 2 * e + 1] = - static_cast(in_IType[w].data.elt[e].y); - } - } } } else if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads @@ -1233,13 +1170,6 @@ __global__ void __launch_bounds__(THREADS_NUM) // Load cached elements in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (FourOverSixConfig::enabled) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - in_4over6_rowwise[swizzled_group_idx + e] = - static_cast(in_cached[w].data.elt[e]); - } - } } } else { #pragma unroll @@ -1265,92 +1195,62 @@ __global__ void __launch_bounds__(THREADS_NUM) elt = static_cast(static_cast(elt)); } in_compute_rowwise[j] = elt; - if constexpr (FourOverSixConfig::enabled) { - in_4over6_rowwise[swizzled_group_idx + e] = elt; - } } } } - if constexpr (FourOverSixConfig::enabled) { - QuantizationCandidates4Over6 candidates; - nvfp4_scale_t S_dec_b_fp8; - const bool pick_map4 = - quantize_and_select_4over6_2d_block_16x( - in_4over6_rowwise, block_amax, S_enc_rowwise, S_dec_rowwise, global_amax_rowwise, - block_in_tile_y, block_in_tile_x, tid_Y_rowwise, *four_over_six_scratch, - S_dec_b_fp8, candidates); - - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } - - store_4over6_rowwise_packed_16x( - pick_map4, candidates, bank_group, thread_offset_X_rowwise, - shmem_offset_base_rowwise_out, out_data_sh); - } else { - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - - // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; - // 3. Scale elements + // 3. Scale elements #pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; + for (int w = 0; w < WAVES; ++w) { + Vec out; #pragma unroll - for (int e = 0; e < PACK_SIZE / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - IType2 in01; - IType2 in23; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); - out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else if constexpr (IS_CACHED_ACT_OP) { - const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); - out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else { - const int j = w * PACK_SIZE + 4 * e; - const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); - const float2 in23 = - make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); - out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); - } + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); } - - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - out.store_to(&out_data_sh[shmem_offset_rowwise]); } + + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); } } } @@ -1419,17 +1319,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; - const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; - const int nvfp4_e4m3_max = use_4over6 && quant_config ? quant_config->nvfp4_e4m3_max : 448; - const NVTENVFP44Over6ErrMode use_4over6_err_mode = - use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; - const bool use_4over6_err_use_fast_math = - use_4over6 && quant_config && quant_config->nvfp4_4over6_err_use_fast_math; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); - NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, - "NVFP4 4over6 quantization does not support stochastic rounding."); // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to // return the transposed data. @@ -1535,41 +1427,29 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; - constexpr size_t base_dshmem_size = - in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_NVFP4_4OVER6_CONFIG_SWITCH( - use_4over6, use_4over6_err_mode, use_4over6_err_use_fast_math, FourOverSixConfig, - TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( - nvfp4_e4m3_max, E4M3_MAX, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = - quantize_transpose_nvfp4_kernel; - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel< - COMPUTE_ACTIVATIONS, ParamOP, OP, IType, USE_STOCHASTIC_ROUNDING, - RETURN_TRANSPOSE, E4M3_MAX, FourOverSixConfig>; - } - using FourOverSixScratch = core::QuantizationScratch4Over6< - NVFP4_2D_BLOCK_DIM, NVFP4_2D_BLOCKS_PER_TILE_Y, NVFP4_2D_BLOCKS_PER_TILE_X>; - constexpr size_t dshmem_size = - base_dshmem_size + FourOverSixScratch::template dynamic_shared_memory_size< - use_2d_quantization, FourOverSixConfig::enabled>(); - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, - scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, - amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, - rng_state); - }););););); + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }); + });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index a60097b9df..8adda82131 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -21,7 +21,6 @@ #include "../../../util/ptx.cuh" #include "../../../utils.cuh" #include "../core_nvfp4.cuh" -#include "../quantize_4over6_nvfp4.cuh" namespace transformer_engine { namespace dispatch { @@ -185,13 +184,14 @@ compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const f return static_cast(scale_rcp); } -template -__device__ __forceinline__ void colwise_scaling( - const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr, - nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, - const float global_amax_colwise, const int stage_Y, const int stage_X, const int buff_in, - const int buff_out_tr, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { +template +__device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, + const float S_enc_colwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, + uint4 &random_uint4, int &rnd_idx) { using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; const auto &sIn2x = *reinterpret_cast(sIn_ptr); @@ -231,56 +231,37 @@ __device__ __forceinline__ void colwise_scaling( static_cast(__habs(thread_amax_2x.y))}; #pragma unroll for (int w = 0; w < 2; ++w) { - if constexpr (FourOverSixConfig::enabled) { - __align__(8) uint32_t rOut[SCALE_DIM / 8]; - nvfp4_scale_t S_dec_b_fp8; - const auto scaling_factors = - core::compute_4over6_nvfp4_quantization_scaling_factors(block_amax[w], S_enc_colwise); + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); - core::quantize_4over6_contiguous_16x( - rIn[w], scaling_factors, global_amax_colwise, S_dec_b_fp8, rOut); + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; - // Store scaling factors to SMEM buffer (R2S) - sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); - uint64_t &out_pack_16x = *reinterpret_cast(rOut); - ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], - out_pack_16x); - } else { - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); - - // Store scaling factors to SMEM buffer (R2S) - sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; - - const scaling_coeff_type SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); - - // Scale elements - __align__(8) uint32_t rOut[SCALE_DIM / 8]; + // Scale elements + __align__(8) uint32_t rOut[SCALE_DIM / 8]; #pragma unroll - for (int e = 0; e < SCALE_DIM / 8; ++e) { - const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); - const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); - if constexpr (USE_STOCHASTIC_ROUNDING) { - const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); - const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); - rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( - elts03, elts47, SFcoefficient, rbits03, rbits47); - } else { - rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, - SFcoefficient); - } + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); } - uint64_t &out_pack_16x = *reinterpret_cast(rOut); - ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], - out_pack_16x); } + uint64_t &out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); } } -template +template __device__ __forceinline__ void rowwise_scaling( const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, @@ -332,101 +313,55 @@ __device__ __forceinline__ void rowwise_scaling( } const float block_amax = get_amax_of_pair(thread_amax_2x); - if constexpr (FourOverSixConfig::enabled) { - nvfp4_scale_t S_dec_b_fp8; - float block_S_enc_rowwise; - float block_global_amax; - if constexpr (ROW_SCALED_NVFP4) { - const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; - if (row_idx < rows) { - block_global_amax = amax_rowwise_ptr[row_idx]; - block_S_enc_rowwise = - core::compute_global_encode_scaling_factor_FP4(block_global_amax); - } else { - block_global_amax = 1.0f; - block_S_enc_rowwise = 1.0f; - } - } else { - block_global_amax = *amax_rowwise_ptr; - block_S_enc_rowwise = S_enc_rowwise; - } - const auto scaling_factors = - core::compute_4over6_nvfp4_quantization_scaling_factors(block_amax, block_S_enc_rowwise); - - __align__(8) uint32_t rOut[WAVES]; - if (bank_group == 0) { - core::quantize_4over6_pair_array_16x( - rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); - } else { - core::quantize_4over6_pair_array_16x( - rIn, scaling_factors, block_global_amax, S_dec_b_fp8, rOut); - } - - // Store scaling factors to SMEM buffer (R2S) - if (SF_storing_thread) { - const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = stage_rowwise_scales_offset_X; - sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; - } - -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; - const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; - ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], rOut[w]); - } + nvfp4_scale_t S_dec_b_fp8; + scaling_coeff_type SFcoefficient; + if constexpr (ROW_SCALED_NVFP4) { + const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + const float S_enc_rowwise_block = + row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) + : 1.0f; + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); } else { - nvfp4_scale_t S_dec_b_fp8; - scaling_coeff_type SFcoefficient; - if constexpr (ROW_SCALED_NVFP4) { - const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; - const float S_enc_rowwise_block = - row_idx < rows - ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) - : 1.0f; - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); - SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); - } else { - S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); - } + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + } - // Store scaling factors to SMEM buffer (R2S) - if (SF_storing_thread) { - const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = stage_rowwise_scales_offset_X; - sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; - } + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } // Scale elements #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); - const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); - - uint32_t out_x8; - if constexpr (USE_STOCHASTIC_ROUNDING) { - const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); - const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); - out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( - elts03, elts47, SFcoefficient, rbits03, rbits47); - } else { - out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, - SFcoefficient); - } - - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; - const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; - ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + for (int w = 0; w < WAVES; ++w) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); + + uint32_t out_x8; + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); } + + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); } } } template + bool ROW_SCALED_NVFP4> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -494,17 +429,12 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) ? 1.0f - : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); const float S_enc_colwise = (amax_colwise_ptr == nullptr) ? S_enc_rowwise - : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); - // Original NVFP4 uses a scalar per-tensor amax for both rowwise and columnwise output. - // If no dedicated columnwise amax buffer is allocated, the rowwise amax is that same scalar. - const float global_amax_colwise = (amax_colwise_ptr == nullptr) - ? ((amax_rowwise_ptr == nullptr) ? 1.0f : *amax_rowwise_ptr) - : *amax_colwise_ptr; + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); __shared__ uint64_t workID_mbar; __shared__ __uint128_t workID_response; @@ -652,15 +582,14 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling(sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, - stage_X, buff_in, buff_out, amax_rowwise_ptr, - block_offset_Y, rows, rng, random_uint4, rnd_idx); + rowwise_scaling( + sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, + amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { - colwise_scaling( - sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, global_amax_colwise, stage_Y, - stage_X, buff_in, buff_out_tr, rng, random_uint4, rnd_idx); + colwise_scaling( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, stage_Y, stage_X, buff_in, + buff_out_tr, rng, random_uint4, rnd_idx); } // Wait for shared memory writes to be visible to TMA engine @@ -761,12 +690,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, using namespace ptx; const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; - const bool use_4over6 = quant_config ? quant_config->nvfp4_4over6 : false; - const int nvfp4_e4m3_max = use_4over6 && quant_config ? quant_config->nvfp4_e4m3_max : 448; - const NVTENVFP44Over6ErrMode use_4over6_err_mode = - use_4over6 && quant_config ? quant_config->nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; - const bool use_4over6_err_use_fast_math = - use_4over6 && quant_config && quant_config->nvfp4_4over6_err_use_fast_math; + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; // If transposed output is allocated, return the transposed data @@ -786,8 +710,6 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, "Row-scaled NVFP4 quantization requires rowwise amax."); NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), "Row-scaled NVFP4 quantization does not produce columnwise output."); - NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, - "NVFP4 4over6 quantization does not support stochastic rounding."); if (return_transpose) { NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), @@ -821,11 +743,6 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); const float *const amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); - if (use_4over6 && return_transpose && amax_colwise_ptr == nullptr) { - NVTE_CHECK(amax_rowwise_ptr != nullptr && output->amax.numel() == 1, - "NVFP4 4over6 quantization with columnwise output requires columnwise amax " - "or scalar per-tensor rowwise amax."); - } const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; const size_t *rng_state = nullptr; @@ -878,51 +795,24 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const int dshmem_size = in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; - if (use_4over6) { - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_NVFP4_4OVER6_CONFIG_SWITCH( - /*USE_4OVER6_VALUE=*/true, use_4over6_err_mode, use_4over6_err_use_fast_math, - FourOverSixConfig, - TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( - nvfp4_e4m3_max, E4M3_MAX, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< - USE_STOCHASTIC_ROUNDING, - /*USE_FAST_MATH=*/false, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, E4M3_MAX, - FourOverSixConfig>; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, - scales_ptr, scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, - amax_colwise_ptr, rows, cols, scale_stride, scale_stride_transpose, - rng_state); - }););););); - } else { - const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_fast_math, USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - row_scaled_nvfp4, ROW_SCALED_NVFP4, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel< - USE_STOCHASTIC_ROUNDING, USE_FAST_MATH, RETURN_TRANSPOSE, ROW_SCALED_NVFP4, - /*E4M3_MAX=*/448, core::NVFP44Over6DisabledConfig>; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, - cols, scale_stride, scale_stride_transpose, rng_state); - });););); - } + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, ROW_SCALED_NVFP4, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });););); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 9586927508..c462b30147 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -68,8 +68,6 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const int nvfp4_e4m3_max, - const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, const bool nvfp4_4over6_err_use_fast_math, const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index a885ea12da..cf9821f1a9 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -14,8 +14,6 @@ #include #include -#include "common/cast/nvfp4/core_nvfp4.cuh" -#include "common/cast/nvfp4/quantize_4over6_nvfp4.cuh" #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" @@ -34,8 +32,6 @@ using std::uint32_t; using std::uint8_t; using transformer_engine::detail::TypeExtrema; -namespace nvfp4_core = transformer_engine::dispatch::nvfp4::core; -using transformer_engine::dispatch::nvfp4::core::compute_global_encode_scaling_factor_FP4; // clang-format off /* @@ -191,6 +187,19 @@ __device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scal return static_cast(input) * encode_scale; } +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float fp8_max = TypeExtrema::max; + constexpr float fp4_max = TypeExtrema::max; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.f || global_encode_scale == 0.f) { + return 1.f; + } + return global_encode_scale; +} + __device__ __forceinline__ uint32_t get_rbits( transformer_engine::curanddx::detail::philox4x32_native_state& rng, // NVTE_BUILD_NUM_PHILOX_ROUNDS rounds of philox4x32 @@ -307,8 +316,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kRowScaledNVFP4> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -362,12 +370,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1; __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim]; __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; - constexpr int k4Over62DSelectionDim = - (FourOverSixConfig::enabled && kIs2DBlockScaling) ? kFP4BlockScalingSize : 1; - using FourOverSixScratch = - nvfp4_core::QuantizationScratch4Over6; - __shared__ FourOverSixScratch four_over_six_scratch; // Step 1: Load input to shared memory { @@ -416,7 +418,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; const float global_encode_scale = - kIsE8Scaling ? 1.0f : compute_global_encode_scaling_factor_FP4(global_amax[0]); + kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; const float global_encode_scale_multiplier = global_encode_scale * fp4_max_inv; const float global_decode_scale = 1.0 / global_encode_scale; @@ -511,64 +513,15 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo float row_global_encode_scale = global_encode_scale; if constexpr (kRowScaledNVFP4) { row_global_encode_scale = - row_idx < num_rows - ? compute_global_encode_scaling_factor_FP4(global_amax[row_idx]) - : 1.0f; + row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) : 1.0f; } const float row_global_encode_scale_multiplier = kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; const float row_global_decode_scale = kRowScaledNVFP4 ? 1.0f / row_global_encode_scale : global_decode_scale; - ScaleType scale_inv; - float encode_scale; - OVec output_vec; - if constexpr (FourOverSixConfig::enabled) { - const auto scaling_factors = - nvfp4_core::compute_4over6_fp4_encode_quantization_scaling_factors( - amax, row_global_encode_scale, row_global_decode_scale); - float row_global_amax; - if constexpr (kRowScaledNVFP4) { - if (row_idx < num_rows) { - row_global_amax = global_amax[row_idx]; - } else { - row_global_amax = 1.0f; - } - } else { - row_global_amax = global_amax[0]; - } - - if constexpr (kIs2DBlockScaling) { - constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; - const int warp_idx = threadIdx.x / kThreadsPerWarp; - const int tid_in_warp_x = threadIdx.x % kNumThreadsStore; - const int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; - const int data_row_idx = - iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; - const size_t block_in_tile_y = data_row_idx / kFP4BlockScalingSize; - const size_t block_in_tile_x = tid_in_warp_x; - const size_t participant_idx = data_row_idx % kFP4BlockScalingSize; - - nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec2_array_candidates_16x( - smem_vec, scaling_factors, row_global_amax, candidates); - const bool pick_map4 = - nvfp4_core::record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, - four_over_six_scratch, scale_inv, candidates); - - nvfp4_core::store_selected_4over6_packed_16x(pick_map4, candidates, output_vec); - } else { - uint32_t output_vec_4over6[2]; - nvfp4_core::quantize_4over6_vec2_array_16x( - smem_vec, scaling_factors, row_global_amax, scale_inv, output_vec_4over6); - nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); - } - } else { - scale_inv = ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); - encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); - } + ScaleType scale_inv = + ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); // Step 2.5: Write scale_inv bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -588,24 +541,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } // Step 2.6: Quantize - if constexpr (!FourOverSixConfig::enabled) { + OVec output_vec; #pragma unroll - for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { - // Pack two elements into __nv_bfloat162 - float2 f2_a; - float2 f2_b; - f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); - f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); - f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); - f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); - const uint32_t rbits = - kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __nv_fp4x4_e2m1 - __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - - output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; - } + for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); + const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; } // Step 2.7: Store output_c if constexpr (kAligned) { @@ -692,48 +643,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo amax = __shfl_sync(mask, amax, src_lane); } // Step 3.4: Compute scale - ScaleType scale_inv; - float encode_scale; - OVec output_vec; - if constexpr (FourOverSixConfig::enabled) { - const auto scaling_factors = - nvfp4_core::compute_4over6_fp4_encode_quantization_scaling_factors( - amax, global_encode_scale, global_decode_scale); - - if constexpr (kIs2DBlockScaling) { - const int warp_idx = threadIdx.x / kThreadsPerWarp; - constexpr int kNumColsPerWarp = kThreadsPerWarp / kNumThreadsStore * kNVecSMem; - constexpr int kNumWarpsPerBlock = kThreadsPerBlock / kThreadsPerWarp; - constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock; - const int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp; - const int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore; - const int data_col_idx = - iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x; - const size_t block_in_tile_y = tid_in_warp_y; - const size_t block_in_tile_x = data_col_idx / kFP4BlockScalingSize; - const size_t participant_idx = data_col_idx % kFP4BlockScalingSize; - - nvfp4_core::QuantizationCandidates4Over6 candidates; - nvfp4_core::quantize_4over6_vec_index_candidates_16x( - smem_vec, smem_idx, scaling_factors, global_amax[0], candidates); - const bool pick_map4 = - nvfp4_core::record_and_select_4over6_2d_block( - scaling_factors, block_in_tile_y, block_in_tile_x, participant_idx, - four_over_six_scratch, scale_inv, candidates); - - nvfp4_core::store_selected_4over6_packed_16x(pick_map4, candidates, output_vec); - } else { - uint32_t output_vec_4over6[2]; - nvfp4_core::quantize_4over6_vec_index_16x( - smem_vec, smem_idx, scaling_factors, global_amax[0], scale_inv, output_vec_4over6); - nvfp4_core::store_4over6_packed_16x(output_vec_4over6, output_vec); - } - } else { - scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); - encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); - } + ScaleType scale_inv = + ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); // Step 3.5: Write scale_inv_t bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -753,29 +665,27 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } // Step 3.6: Quantize - if constexpr (!FourOverSixConfig::enabled) { + OVec output_vec; #pragma unroll - for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { - // Pack two elements into __nv_bfloat162 - float2 f2_a; - float2 f2_b; - f2_a.x = ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], - encode_scale); - f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], - encode_scale); - f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], - encode_scale); - f2_b.y = ComputeOutputFP4( - smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], encode_scale); - const uint32_t rbits = - kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; - // Convert to __nv_fp4x4_e2m1 - __nv_fp4x4_e2m1 out_4x = - cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); - - output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; - output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; - } + for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = + ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], + encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], + encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], + encode_scale); + const uint32_t rbits = + kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; } // Step 3.7: Store output_t if constexpr (kAligned) { @@ -810,8 +720,6 @@ void quantize_transpose_vector_blockwise_fp4( const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, - const bool use_4over6, const int nvfp4_e4m3_max, - const NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode, const bool nvfp4_4over6_err_use_fast_math, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -829,16 +737,6 @@ void quantize_transpose_vector_blockwise_fp4( "Row-scaled NVFP4 quantization only supports rowwise quantization."); NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); - NVTE_CHECK(!use_4over6 || !use_stochastic_rounding, - "NVFP4 4over6 quantization does not support stochastic rounding."); - NVTE_CHECK(nvfp4_e4m3_max == 448 || nvfp4_e4m3_max == 256, "Unsupported NVFP4 E4M3 max (got ", - nvfp4_e4m3_max, ")"); - NVTE_CHECK(use_4over6 || nvfp4_e4m3_max == 448, - "Non-4over6 NVFP4 quantization requires E4M3 max 448."); - const NVTENVFP44Over6ErrMode use_4over6_err_mode = - use_4over6 ? nvfp4_4over6_err_mode : kNVTENVFP44Over6ErrMAE; - const bool use_4over6_err_use_fast_math = use_4over6 && nvfp4_4over6_err_use_fast_math; - const int enabled_nvfp4_e4m3_max = use_4over6 ? nvfp4_e4m3_max : 448; const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -912,104 +810,47 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( swizzled_scale, kSwizzledScale, - if (use_4over6) { - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_2d_quantization, kIs2DBlockScaling, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - row_scaled_nvfp4, kRowScaledNVFP4, - - TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( - enabled_nvfp4_e4m3_max, kE4M3Max, - - TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( - use_4over6_err_mode, k4Over6ErrMode, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_4over6_err_use_fast_math, - kUse4Over6ErrUseFastMath, - - size_t smem_bytes = kSMemSize * sizeof(InputType); - using FourOverSixConfig = - nvfp4_core::NVFP44Over6Config< - true, k4Over6ErrMode, - kUse4Over6ErrUseFastMath>; - auto kernel = block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, kPow2Scale, - kAligned, float, InputType, OutputType, - ScaleType, kSwizzledScale, - /*kApplyStochasticRounding=*/false, - kIs2DBlockScaling, kRowScaledNVFP4, kE4M3Max, - FourOverSixConfig>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK(err == cudaSuccess, - "Failed to set dynamic shared memory " - "size."); - } kernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast( - global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), - row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, - scale_t_stride_y, kScaleBlockDim, epsilon, - rng_state, - noop_ptr);) // kUse4Over6ErrUseFastMath - ) // k4Over6ErrMode - ) // kE4M3Max - ) // kRowScaledNVFP4 - ) // kIs2DBlockScaling - } else { - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, kApplyStochasticRounding, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_2d_quantization, kIs2DBlockScaling, - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - row_scaled_nvfp4, kRowScaledNVFP4, - - size_t smem_bytes = kSMemSize * sizeof(InputType); - auto kernel = block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, - float, InputType, OutputType, ScaleType, kSwizzledScale, - kApplyStochasticRounding, kIs2DBlockScaling, - kRowScaledNVFP4, /*kE4M3Max=*/448, - nvfp4_core::NVFP44Over6DisabledConfig>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK(err == cudaSuccess, - "Failed to set dynamic shared memory size."); - } kernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), - row_length, num_rows, scale_stride_x, scale_stride_y, - scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, - epsilon, rng_state, - noop_ptr);) // kRowScaledNVFP4 - ) // kIs2DBlockScaling - ) // kApplyStochasticRounding - }) // kSwizzledScale - ) // kAligned - ) // kReturnTranspose - ) // kReturnIdentity - ) // OutputType - ) // InputType + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kApplyStochasticRounding, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, kRowScaledNVFP4, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling, + kRowScaledNVFP4>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, scale_stride_y, + scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, + epsilon, rng_state, + noop_ptr);) // kRowScaledNVFP4 + ) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); #else From 93dbf2beca450554de29e201a1d8821f11005fbb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 12 May 2026 21:20:22 -0700 Subject: [PATCH 53/57] Use cp async Signed-off-by: Ziang Li --- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 147 +++++++++++++----- 1 file changed, 107 insertions(+), 40 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index d8fa24c1ea..4d7c637a55 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -26,6 +26,8 @@ #include #include +#include + #include "../../common.h" #include "../../util/math.h" #include "../../utils.cuh" @@ -71,8 +73,13 @@ constexpr int kTileRows = 128; constexpr int kTileCols = 64; constexpr int kTileColGroups = kTileCols / kGroupSize; constexpr int kTileRowGroups = kTileRows / kGroupSize; +constexpr int kPipelineStages = 2; +constexpr int kStageRows = kTileRows / kPipelineStages; +constexpr int kStageRowGroups = kStageRows / kGroupSize; constexpr int kElementsPerHalfGroup = 8; constexpr int kPackedWordsPerGroup = 2; +static_assert(kTileRows == kPipelineStages * kStageRows); +static_assert(kStageRows % kGroupSize == 0); template struct Config { @@ -152,11 +159,16 @@ template __device__ __forceinline__ void load_row_group(const IType *tile, const int row, const int col_start, float (&x0)[8], float (&x1)[8], float *amax) { + Vec x0_vec; + Vec x1_vec; + x0_vec.load_from(&tile[row * kTileCols + col_start]); + x1_vec.load_from(&tile[row * kTileCols + col_start + kElementsPerHalfGroup]); + *amax = 0.0f; #pragma unroll for (int i = 0; i < kElementsPerHalfGroup; ++i) { - const float v0 = load_input(tile, row * kTileCols + col_start + i); - const float v1 = load_input(tile, row * kTileCols + col_start + i + kElementsPerHalfGroup); + const float v0 = static_cast(x0_vec.data.elt[i]); + const float v1 = static_cast(x1_vec.data.elt[i]); x0[i] = v0; x1[i] = v1; *amax = fmaxf(*amax, fabsf(v0)); @@ -294,9 +306,9 @@ __device__ __forceinline__ float reduce_group_max_16(float value) { } __device__ __forceinline__ void store_packed_group(const uint32_t *packed, fp4e2m1x2 *dst) { - auto *dst32 = reinterpret_cast(dst); - dst32[0] = packed[0]; - dst32[1] = packed[1]; + const uint64_t packed64 = + static_cast(packed[0]) | (static_cast(packed[1]) << 32); + *reinterpret_cast(dst) = packed64; } __device__ __forceinline__ const uint32_t *select_packed(const CandidatePair &candidates, @@ -315,26 +327,55 @@ __device__ __forceinline__ nvfp4_scale_t select_scale(const ScalePair &scales, return scales.map6; } +__device__ __forceinline__ void cp_async_cg_16(void *dst, const void *src) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst); + const uint64_t src_gmem_ptr = reinterpret_cast(src); + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(dst_smem_ptr), + "l"(src_gmem_ptr)); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_commit_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.commit_group;\n" ::); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +template +__device__ __forceinline__ void cp_async_wait_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + template -__device__ void load_tile_to_shared(const IType *input, IType *tile, const size_t rows, - const size_t cols, const size_t tile_row, - const size_t tile_col) { +__device__ void load_stage_to_shared_async(const IType *input, IType *tile, const size_t rows, + const size_t cols, const size_t stage_row, + const size_t tile_col) { constexpr int vec_elems = 16 / sizeof(IType); constexpr int vecs_per_row = kTileCols / vec_elems; - constexpr int vecs = kTileRows * vecs_per_row; + constexpr int vecs = kStageRows * vecs_per_row; using TileVec = Vec; for (int idx = threadIdx.x; idx < vecs; idx += blockDim.x) { const int local_row = idx / vecs_per_row; const int local_vec_col = idx - local_row * vecs_per_row; const int local_col = local_vec_col * vec_elems; - const size_t global_row = tile_row + local_row; + const size_t global_row = stage_row + local_row; const size_t global_col = tile_col + local_col; + IType *stage_ptr = &tile[local_row * kTileCols + local_col]; - TileVec vec; if (global_row < rows && global_col + vec_elems <= cols) { - vec.load_from(&input[global_row * cols + global_col]); + cp_async_cg_16(stage_ptr, &input[global_row * cols + global_col]); } else { + TileVec vec; vec.clear(); #pragma unroll for (int i = 0; i < vec_elems; ++i) { @@ -342,23 +383,23 @@ __device__ void load_tile_to_shared(const IType *input, IType *tile, const size_ vec.data.elt[i] = input[global_row * cols + global_col + i]; } } + vec.store_to(stage_ptr); } - vec.store_to(&tile[local_row * kTileCols + local_col]); } } template -__device__ void quantize_tile_rowwise(const IType *tile, fp4e2m1x2 *output, nvfp4_scale_t *scales, - const float *amax, const size_t rows, const size_t cols, - const size_t tile_row, const size_t tile_col, - const size_t scale_stride) { - constexpr int groups = kTileRows * kTileColGroups; +__device__ void quantize_stage_rowwise(const IType *tile, fp4e2m1x2 *output, nvfp4_scale_t *scales, + const float *amax, const size_t rows, const size_t cols, + const size_t stage_row, const size_t tile_col, + const size_t scale_stride) { + constexpr int groups = kStageRows * kTileColGroups; for (int group = threadIdx.x; group < groups; group += blockDim.x) { - const int local_row = group % kTileRows; - const int local_col_group = group / kTileRows; + const int local_row = group % kStageRows; + const int local_col_group = group / kStageRows; const int local_col = local_col_group * kGroupSize; - const size_t global_row = tile_row + local_row; + const size_t global_row = stage_row + local_row; const size_t global_col = tile_col + local_col; if (global_row >= rows || global_col >= cols) { continue; @@ -400,16 +441,16 @@ __device__ void quantize_tile_rowwise(const IType *tile, fp4e2m1x2 *output, nvfp } template -__device__ void quantize_tile_colwise(const IType *tile, fp4e2m1x2 *output_t, - nvfp4_scale_t *scales_t, const float *amax, const size_t rows, - const size_t cols, const size_t tile_row, - const size_t tile_col, const size_t scale_stride_t) { - constexpr int groups = kTileRowGroups * kTileCols; +__device__ void quantize_stage_colwise(const IType *tile, fp4e2m1x2 *output_t, + nvfp4_scale_t *scales_t, const float *amax, + const size_t rows, const size_t cols, const size_t stage_row, + const size_t tile_col, const size_t scale_stride_t) { + constexpr int groups = kStageRowGroups * kTileCols; for (int group = threadIdx.x; group < groups; group += blockDim.x) { const int local_row_group = group / kTileCols; const int local_col = group - local_row_group * kTileCols; const int local_row = local_row_group * kGroupSize; - const size_t global_row = tile_row + local_row; + const size_t global_row = stage_row + local_row; const size_t global_col = tile_col + local_col; if (global_row >= rows || global_col >= cols) { continue; @@ -460,25 +501,51 @@ __global__ void __launch_bounds__(kThreads) } extern __shared__ char dynamic_shmem[]; - auto *tile = reinterpret_cast(dynamic_shmem); + auto *tiles = reinterpret_cast(dynamic_shmem); const size_t tile_col = blockIdx.x * kTileCols; const size_t tile_row = blockIdx.y * kTileRows; - load_tile_to_shared(input, tile, rows, cols, tile_row, tile_col); + IType *stage_tiles[kPipelineStages] = { + &tiles[0], + &tiles[kStageRows * kTileCols], + }; + + load_stage_to_shared_async(input, stage_tiles[0], rows, cols, tile_row, tile_col); + cp_async_commit_group(); + cp_async_wait_group<0>(); __syncthreads(); - if constexpr (RETURN_IDENTITY) { - quantize_tile_rowwise( - tile, output, scales, amax_rowwise, rows, cols, tile_row, tile_col, scale_stride); - } + for (int stage = 0; stage < kPipelineStages; ++stage) { + const int next_stage = stage + 1; + if (next_stage < kPipelineStages) { + const size_t next_stage_row = tile_row + next_stage * kStageRows; + load_stage_to_shared_async(input, stage_tiles[next_stage], rows, cols, next_stage_row, + tile_col); + cp_async_commit_group(); + } + + const size_t stage_row = tile_row + stage * kStageRows; + IType *stage_tile = stage_tiles[stage]; + + if constexpr (RETURN_IDENTITY) { + quantize_stage_rowwise( + stage_tile, output, scales, amax_rowwise, rows, cols, stage_row, tile_col, scale_stride); + } + + if constexpr (RETURN_TRANSPOSE) { + const float *columnwise_amax = amax_colwise; + if (columnwise_amax == nullptr) { + columnwise_amax = amax_rowwise; + } + quantize_stage_colwise( + stage_tile, output_t, scales_t, columnwise_amax, rows, cols, stage_row, tile_col, + scale_stride_t); + } - if constexpr (RETURN_TRANSPOSE) { - const float *columnwise_amax = amax_colwise; - if (columnwise_amax == nullptr) { - columnwise_amax = amax_rowwise; + if (next_stage < kPipelineStages) { + cp_async_wait_group<0>(); + __syncthreads(); } - quantize_tile_colwise( - tile, output_t, scales_t, columnwise_amax, rows, cols, tile_row, tile_col, scale_stride_t); } #else NVTE_DEVICE_ERROR("sm_100 or higher is required."); @@ -506,7 +573,7 @@ void launch_quantize_4over6(const Tensor &input, const Tensor *noop, Tensor *out const dim3 grid(DIVUP(cols, static_cast(kTileCols)), DIVUP(rows, static_cast(kTileRows))); const dim3 block(kThreads); - const size_t shmem = kTileRows * kTileCols * sizeof(IType); + const size_t shmem = kPipelineStages * kStageRows * kTileCols * sizeof(IType); const size_t scale_stride = return_identity ? output->scale_inv.shape[1] : 0; const size_t scale_stride_t = return_transpose ? output->columnwise_scale_inv.shape[1] : 0; From 8819d127a1ee5bfcad41d5e6affca7cfbe718bfa Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 13 May 2026 00:30:41 -0700 Subject: [PATCH 54/57] Add benchmark script Signed-off-by: Ziang Li --- benchmarks/benchmark_4over6.py | 164 +++++++++++++++--- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 9 +- 2 files changed, 144 insertions(+), 29 deletions(-) diff --git a/benchmarks/benchmark_4over6.py b/benchmarks/benchmark_4over6.py index af2b7b611a..8f4388dbe5 100644 --- a/benchmarks/benchmark_4over6.py +++ b/benchmarks/benchmark_4over6.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. +import argparse import os import torch @@ -12,12 +13,33 @@ from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer -SHAPES = [ +BENCHMARK_SHAPES = [ + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), +] +PROFILE_SHAPES = [ (16384, 6144), ] MIN_RUN_TIME = 5 +# Nsight Compute profiling command: +# ncu -f -o nvfp4_4over6 --set=full --profile-from-start off --target-processes all \ +# python3 benchmarks/benchmark_4over6.py --profile + + def make_quantizer(use_2d_quantization, use_4over6, err_mode): return NVFP4Quantizer( fp4_dtype=tex.DType.kFloat4E2M1, @@ -58,9 +80,65 @@ def benchmark_quantize(shape, use_2d_quantization, use_4over6, err_mode, err_fas return timing.median * 1e6 -def main(): +def iter_profile_cases(): + for shape in PROFILE_SHAPES: + for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): + yield shape, mode_name, "nvfp4", "MAE", False, use_2d_quantization, False + + for err_mode in ("MAE", "MSE"): + for err_fast_math in (False, True): + yield ( + shape, + mode_name, + "4over6", + err_mode, + err_fast_math, + use_2d_quantization, + True, + ) + + +def prepare_profile_case(case): + shape, mode_name, kernel, err_mode, err_fast_math, use_2d_quantization, use_4over6 = case + set_err_fast_math(err_fast_math) + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) + out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) + quantizer.update_quantized(x, out) + torch.cuda.synchronize() + return { + "shape": shape, + "mode_name": mode_name, + "kernel": kernel, + "err_mode": err_mode, + "err_fast_math": err_fast_math, + "quantizer": quantizer, + "x": x, + "out": out, + } + + +def run_profile(profile_repeats): + cases = [prepare_profile_case(case) for case in iter_profile_cases()] + torch.cuda.synchronize() + torch.cuda.cudart().cudaProfilerStart() + for case in cases: + set_err_fast_math(case["err_fast_math"]) + print( + "PROFILE " + f"shape={case['shape']} mode={case['mode_name']} kernel={case['kernel']} " + f"err={case['err_mode']} err_fast={case['err_fast_math']}", + flush=True, + ) + for _ in range(profile_repeats): + case["quantizer"].update_quantized(case["x"], case["out"]) + torch.cuda.synchronize() + torch.cuda.cudart().cudaProfilerStop() + + +def run_benchmark(): rows = [] - for shape in SHAPES: + for shape in BENCHMARK_SHAPES: for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): baseline_us = benchmark_quantize( shape=shape, @@ -69,39 +147,75 @@ def main(): err_mode="MAE", err_fast_math=False, ) - rows.append((shape, mode_name, "nvfp4", "-", "-", baseline_us, 1.0)) + rows.append((shape, mode_name, "nvfp4", "-", baseline_us, 1.0, None, None)) for err_mode in ("MAE", "MSE"): - for err_fast_math in (False, True): - timing_us = benchmark_quantize( - shape=shape, - use_2d_quantization=use_2d_quantization, - use_4over6=True, - err_mode=err_mode, - err_fast_math=err_fast_math, - ) - rows.append( - ( - shape, - mode_name, - "4over6", - err_mode, - str(err_fast_math), - timing_us, - timing_us / baseline_us, - ) + strict_timing_us = benchmark_quantize( + shape=shape, + use_2d_quantization=use_2d_quantization, + use_4over6=True, + err_mode=err_mode, + err_fast_math=False, + ) + fast_timing_us = benchmark_quantize( + shape=shape, + use_2d_quantization=use_2d_quantization, + use_4over6=True, + err_mode=err_mode, + err_fast_math=True, + ) + rows.append( + ( + shape, + mode_name, + "4over6", + err_mode, + strict_timing_us, + strict_timing_us / baseline_us, + fast_timing_us, + fast_timing_us / baseline_us, ) + ) print( f"{'shape':>18} {'mode':>4} {'kernel':>7} {'err':>3} " - f"{'err_fast':>8} {'time_us':>10} {'slowdown':>8}" + f"{'strict_us':>10} {'strict':>8} {'fast_us':>10} {'fast':>8}" ) - for shape, mode_name, kernel, err_mode, err_fast_math, timing_us, slowdown in rows: + for ( + shape, + mode_name, + kernel, + err_mode, + strict_us, + strict_slowdown, + fast_us, + fast_slowdown, + ) in rows: + fast_us_str = "-" if fast_us is None else f"{fast_us:10.3f}" + fast_slowdown_str = "-" if fast_slowdown is None else f"{fast_slowdown:8.3f}x" print( f"{str(shape):>18} {mode_name:>4} {kernel:>7} {err_mode:>3} " - f"{err_fast_math:>8} {timing_us:10.3f} {slowdown:8.3f}x" + f"{strict_us:10.3f} {strict_slowdown:8.3f}x " + f"{fast_us_str:>10} {fast_slowdown_str:>8}" ) +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable Nsight Compute profile mode") + parser.add_argument( + "--profile-repeats", + default=1, + type=int, + help="Number of profiled update_quantized calls per case", + ) + args = parser.parse_args() + + if args.profile: + run_profile(args.profile_repeats) + else: + run_benchmark() + + if __name__ == "__main__": main() diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh index 4d7c637a55..a0b67c8e47 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -505,10 +505,11 @@ __global__ void __launch_bounds__(kThreads) const size_t tile_col = blockIdx.x * kTileCols; const size_t tile_row = blockIdx.y * kTileRows; - IType *stage_tiles[kPipelineStages] = { - &tiles[0], - &tiles[kStageRows * kTileCols], - }; + IType *stage_tiles[kPipelineStages]; +#pragma unroll + for (int stage = 0; stage < kPipelineStages; ++stage) { + stage_tiles[stage] = &tiles[stage * kStageRows * kTileCols]; + } load_stage_to_shared_async(input, stage_tiles[0], rows, cols, tile_row, tile_col); cp_async_commit_group(); From 24e417ba0f6104a9a0de6527c8d6903da8b95d05 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 13 May 2026 00:48:02 -0700 Subject: [PATCH 55/57] Minor fix after rebase Signed-off-by: Ziang Li --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 16 +++++----------- .../pytorch/csrc/extensions/cast.cpp | 1 + 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index b13102f1f9..4825deb329 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -128,19 +128,13 @@ NVFP4FourOverSixQuantization compute_4over6_quantization_scales( const fp8e4m3 scale_map4 = static_cast(sf_high_precision_map4); const fp8e4m3 scale_map6 = static_cast(sf_high_precision_map6); - float reciprocal_map4 = 0.0f; + const float global_decode_scale = 1.0f / global_encode_scale; const float scale_map4_fp32 = static_cast(scale_map4); - if (scale_map4_fp32 != 0.0f) { - reciprocal_map4 = fminf(global_encode_scale / scale_map4_fp32, - Numeric_Traits::maxNorm); - } - - float reciprocal_map6 = 0.0f; + const float reciprocal_map4 = + fminf(1.0f / (scale_map4_fp32 * global_decode_scale), Numeric_Traits::maxNorm); const float scale_map6_fp32 = static_cast(scale_map6); - if (scale_map6_fp32 != 0.0f) { - reciprocal_map6 = fminf(global_encode_scale / scale_map6_fp32, - Numeric_Traits::maxNorm); - } + const float reciprocal_map6 = + fminf(1.0f / (scale_map6_fp32 * global_decode_scale), Numeric_Traits::maxNorm); const float2 zero = {0.0f, 0.0f}; return { diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index d64098903d..35b7f197ac 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -726,6 +726,7 @@ std::tuple, std::vector, bool> bulk_alloc const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const bool use_4over6 = quantizer_cpp_list[0]->use_4over6; const int nvfp4_e4m3_max = quantizer_cpp_list[0]->nvfp4_e4m3_max; + const auto nvfp4_4over6_err_mode = quantizer_cpp_list[0]->nvfp4_4over6_err_mode; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); From 472e5b8b92ef9ae31d33ea36f6a0df6247372053 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 13 May 2026 02:36:13 -0700 Subject: [PATCH 56/57] Naming consistency Signed-off-by: Ziang Li --- benchmarks/benchmark_4over6.py | 2 +- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 18 ++++----- .../nvfp4/test_nvfp4_quantize_exact.py | 16 ++++---- tests/pytorch/test_cpu_offloading.py | 2 +- tests/pytorch/test_fusible_ops.py | 2 +- tests/pytorch/test_quantized_tensor.py | 4 +- tests/pytorch/test_recipe.py | 12 +++--- .../common/cast/dispatch/quantize.cuh | 23 ++++++----- transformer_engine/pytorch/csrc/common.h | 2 +- .../pytorch/csrc/extensions/cast.cpp | 34 ++++++++-------- transformer_engine/pytorch/csrc/quantizer.cpp | 31 +++++++------- .../pytorch/csrc/type_converters.cpp | 4 +- .../custom_recipes/quantization_ref_nvfp4.py | 40 ++++++++++--------- transformer_engine/pytorch/quantization.py | 12 +++--- .../pytorch/tensor/grouped_tensor.py | 6 +-- .../pytorch/tensor/nvfp4_tensor.py | 40 +++++++++---------- .../tensor/storage/grouped_tensor_storage.py | 34 ++++++++-------- .../tensor/storage/nvfp4_tensor_storage.py | 14 +++---- 18 files changed, 151 insertions(+), 145 deletions(-) diff --git a/benchmarks/benchmark_4over6.py b/benchmarks/benchmark_4over6.py index 8f4388dbe5..af61417a03 100644 --- a/benchmarks/benchmark_4over6.py +++ b/benchmarks/benchmark_4over6.py @@ -52,7 +52,7 @@ def make_quantizer(use_2d_quantization, use_4over6, err_mode): with_2d_quantization=use_2d_quantization, stochastic_rounding=False, row_scaled_nvfp4=False, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=448, nvfp4_4over6_err_mode=err_mode, with_random_sign_mask=True, diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 16a9387dc6..bd4d029729 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -64,7 +64,7 @@ def check_nvfp4_gemm_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -76,7 +76,7 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -134,7 +134,7 @@ def check_nvfp4_gemm_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -145,7 +145,7 @@ def check_nvfp4_gemm_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -268,7 +268,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=True, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( @@ -279,7 +279,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -364,7 +364,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=True, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_tensorwise_quantizer = NVFP4Quantizer( @@ -375,7 +375,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( @@ -386,7 +386,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 043d71919e..5bb92f70dc 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -83,7 +83,7 @@ def check_quantization_nvfp4_versus_reference( with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -119,7 +119,7 @@ def check_quantization_nvfp4_versus_reference( eps=0.0, quant_tile_shape=quant_tile_shape, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -280,7 +280,7 @@ def test_nvfp4_quantization_extrema_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -313,7 +313,7 @@ def test_nvfp4_quantization_extrema_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -410,7 +410,7 @@ def test_nvfp4_quantization_boundary_values( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -443,7 +443,7 @@ def test_nvfp4_quantization_boundary_values( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -526,7 +526,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -559,7 +559,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 47505c5be0..35cc98a976 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -217,7 +217,7 @@ def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, row_scaled_nvfp4=recipe.row_scaled_activation, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, ) return quantizer(tensor) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 179cc417d5..82d903a39b 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -198,7 +198,7 @@ def make_reference_and_test_tensors( with_2d_quantization=with_2d_quantization, stochastic_rounding=False, with_random_sign_mask=False, - use_4over6=quantization == "nvfp4_4over6", + nvfp4_use_4over6=quantization == "nvfp4_4over6", )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 94f5dd040c..1ddadd8c71 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -175,7 +175,7 @@ def make_reference_and_test_tensors( stochastic_rounding=False, row_scaled_nvfp4=row_scaled_nvfp4, with_random_sign_mask=False, - use_4over6=(quantization == "nvfp4_4over6"), + nvfp4_use_4over6=(quantization == "nvfp4_4over6"), )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -750,7 +750,7 @@ def test_update_nd_tensor( with_post_rht_amax=False, with_2d_quantization=(quantization == "nvfp4_2d"), row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=(quantization == "nvfp4_4over6"), + nvfp4_use_4over6=(quantization == "nvfp4_4over6"), ) quantization = "nvfp4" else: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 0c8de0526f..9ea5e738b4 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -565,7 +565,7 @@ def expected_e4m3_max(tensor_type): num_quantizers=3, ).make_quantizers() assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] - assert [q.use_4over6 for q in forward_quantizers] == [ + assert [q.nvfp4_use_4over6 for q in forward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("input", "weight", "output") ] assert [q.nvfp4_e4m3_max for q in forward_quantizers] == [ @@ -587,7 +587,7 @@ def expected_e4m3_max(tensor_type): ], ).make_quantizers() assert [q.row_scaled_nvfp4 for q in role_quantizers] == [False, True, True, True] - assert [q.use_4over6 for q in role_quantizers] == [ + assert [q.nvfp4_use_4over6 for q in role_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("weight", "input", "output", "input") ] assert [q.nvfp4_e4m3_max for q in role_quantizers] == [ @@ -605,7 +605,7 @@ def expected_e4m3_max(tensor_type): ], ).make_quantizers() assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] - assert [q.use_4over6 for q in backward_quantizers] == [ + assert [q.nvfp4_use_4over6 for q in backward_quantizers] == [ expected_use_4over6(tensor_type) for tensor_type in ("grad_output", "grad_input") ] assert [q.nvfp4_e4m3_max for q in backward_quantizers] == [ @@ -637,17 +637,17 @@ def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N): q = NVFP4Quantizer( columnwise=not row_scaled_nvfp4, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=use_4over6, ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) assert starting_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 - assert starting_tensor._use_4over6 == use_4over6 + assert starting_tensor._nvfp4_use_4over6 == use_4over6 assert starting_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 - assert new_tensor._use_4over6 == use_4over6 + assert new_tensor._nvfp4_use_4over6 == use_4over6 assert new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) # 4over6 can re-encode a dequantized block with the alternate 4/6 scale # choice while preserving the dequantized values. diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index f25a64053f..65b267433e 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -107,10 +107,10 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max, "Tensor and quantization config have inconsistent options for NVFP4 4over6 " "E4M3 scale bound."); - const bool use_4over6 = quant_config_cpp.nvfp4_4over6; - NVTE_CHECK(use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, + const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(nvfp4_use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, "Non-4over6 NVFP4 quantization requires E4M3 max 448."); - NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, + NVTE_CHECK(!nvfp4_use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); if (row_scaled_nvfp4) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, @@ -123,7 +123,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel - if (use_4over6) { + if (nvfp4_use_4over6) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_4over6( *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -273,10 +273,10 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max, "Tensor and quantization config have inconsistent options for NVFP4 4over6 " "E4M3 scale bound."); - const bool use_4over6 = quant_config_cpp.nvfp4_4over6; - NVTE_CHECK(use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, + const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(nvfp4_use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, "Non-4over6 NVFP4 quantization requires E4M3 max 448."); - NVTE_CHECK(!use_4over6 || !quant_config_cpp.stochastic_rounding, + NVTE_CHECK(!nvfp4_use_4over6 || !quant_config_cpp.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); @@ -284,7 +284,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel - if (use_4over6) { + if (nvfp4_use_4over6) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_4over6( *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -417,12 +417,13 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou "Tensor and quantization config have inconsistent options for NVFP4 4over6 " "E4M3 scale bound."); } - const bool use_4over6 = quant_config_cpp.nvfp4_4over6; - NVTE_CHECK(use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, + const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(nvfp4_use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, "Non-4over6 NVFP4 quantization requires E4M3 max 448."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "2D quantization is not supported for group quantize."); - NVTE_CHECK(!use_4over6, "NVFP4 4over6 quantization is not supported for group quantize."); + NVTE_CHECK(!nvfp4_use_4over6, + "NVFP4 4over6 quantization is not supported for group quantize."); // Launch NVFP4 group quantize kernel nvfp4::group_quantize_transpose( diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2ce132adb4..af8b0835d5 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -328,7 +328,7 @@ class NVFP4Quantizer : public Quantizer { bool with_2d_quantization; bool stochastic_rounding; // Whether emitted NVFP4 tensors use 4over6 candidate selection. - bool use_4over6; + bool nvfp4_use_4over6; // Global E4M3 scale bound used by emitted NVFP4 tensors. int nvfp4_e4m3_max; NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 35b7f197ac..89f745afd7 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -84,7 +84,7 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization, "2D scaling grouped quant kernel is not ready yet"); - NVTE_CHECK(!nvfp4_quantizer_cpp->use_4over6, + NVTE_CHECK(!nvfp4_quantizer_cpp->nvfp4_use_4over6, "NVFP4 4over6 quantization is not supported for grouped quantization."); auto quant_config_cpp = QuantizationConfigWrapper(); @@ -724,9 +724,8 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; - const bool use_4over6 = quantizer_cpp_list[0]->use_4over6; + const bool nvfp4_use_4over6 = quantizer_cpp_list[0]->nvfp4_use_4over6; const int nvfp4_e4m3_max = quantizer_cpp_list[0]->nvfp4_e4m3_max; - const auto nvfp4_4over6_err_mode = quantizer_cpp_list[0]->nvfp4_4over6_err_mode; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); @@ -871,11 +870,12 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass( - rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, - amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales, - py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, py::arg("use_4over6") = use_4over6, - py::arg("nvfp4_e4m3_max") = nvfp4_e4m3_max)); + tensor_py_list.emplace_back( + NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, + amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, + py::arg("nvfp4_use_4over6") = nvfp4_use_4over6, + py::arg("nvfp4_e4m3_max") = nvfp4_e4m3_max)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -893,7 +893,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); - tensor_wrapper.set_nvfp4_4over6(use_4over6); + tensor_wrapper.set_nvfp4_4over6(nvfp4_use_4over6); tensor_wrapper.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Set the amax rowwise and amax columnwise if available @@ -1005,7 +1005,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = split_sections.size(); const auto &quantizer = *quantizers.front(); - NVTE_CHECK(!quantizer.use_4over6, + NVTE_CHECK(!quantizer.nvfp4_use_4over6, "NVFP4 4over6 quantization is not supported with RHT split quantization."); std::vector nvte_tensor_input_list; @@ -1043,12 +1043,12 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, need_separate_rng_states, quant_config_list, quant_config_list_colwise); for (auto &config : quant_config_list) { - config.set_nvfp4_4over6(quantizer.use_4over6); + config.set_nvfp4_4over6(quantizer.nvfp4_use_4over6); config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } for (auto &config : quant_config_list_colwise) { - config.set_nvfp4_4over6(quantizer.use_4over6); + config.set_nvfp4_4over6(quantizer.nvfp4_use_4over6); config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } @@ -1063,7 +1063,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // NVFP4 4over6 candidate error math is controlled separately by // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math && !quantizer.use_4over6) { + if (use_fast_math && !quantizer.nvfp4_use_4over6) { for (auto &config : quant_config_list) { config.set_use_fast_math(true); } @@ -1191,7 +1191,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = input_list.size(); const auto &quantizer = *quantizers.front(); - NVTE_CHECK(!quantizer.use_4over6 || !quantizer.stochastic_rounding, + NVTE_CHECK(!quantizer.nvfp4_use_4over6 || !quantizer.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); std::vector nvte_tensor_input_list; @@ -1226,7 +1226,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, dummy_quant_config_list_colwise); // colwise rng states are not needed in this case for (auto &config : quant_config_list) { - config.set_nvfp4_4over6(quantizer.use_4over6); + config.set_nvfp4_4over6(quantizer.nvfp4_use_4over6); config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); } @@ -1234,7 +1234,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // NVFP4 4over6 candidate error math is controlled separately by // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math && !quantizer.use_4over6) { + if (use_fast_math && !quantizer.nvfp4_use_4over6) { for (auto &config : quant_config_list) { config.set_use_fast_math(true); } @@ -1318,7 +1318,7 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, "NVFP4 split-quantize does not support 2D quantization"); NVTE_CHECK(!quantizer.with_amax_reduction, "NVFP4 split-quantize does not support amax reduction"); - if (quantizer.use_4over6) { + if (quantizer.nvfp4_use_4over6) { NVTE_CHECK(!quantizer.with_rht, "NVFP4 4over6 quantization does not support RHT."); NVTE_CHECK(!quantizer.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6d6e173e1b..9dc592fce1 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1729,8 +1729,9 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); - this->use_4over6 = quantizer.attr("use_4over6").cast(); - this->nvfp4_e4m3_max = this->use_4over6 ? quantizer.attr("nvfp4_e4m3_max").cast() : 448; + this->nvfp4_use_4over6 = quantizer.attr("nvfp4_use_4over6").cast(); + this->nvfp4_e4m3_max = + this->nvfp4_use_4over6 ? quantizer.attr("nvfp4_e4m3_max").cast() : 448; NVTE_CHECK(this->nvfp4_e4m3_max == 448 || this->nvfp4_e4m3_max == 256, "Unsupported NVFP4 E4M3 max: ", this->nvfp4_e4m3_max); const auto nvfp4_4over6_err_mode = quantizer.attr("nvfp4_4over6_err_mode").cast(); @@ -1790,7 +1791,7 @@ std::pair NVFP4Quantizer::create_tensor( "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; - const bool use_4over6 = this->use_4over6; + const bool nvfp4_use_4over6 = this->nvfp4_use_4over6; const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); @@ -1859,7 +1860,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); - kwargs["use_4over6"] = py::cast(use_4over6); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); kwargs["fake_dtype"] = GetATenDType(dtype); @@ -1891,7 +1892,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["device"] = py::cast(device); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); - kwargs["use_4over6"] = py::cast(use_4over6); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), @@ -1926,7 +1927,7 @@ std::pair NVFP4Quantizer::create_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); - out_cpp.set_nvfp4_4over6(use_4over6); + out_cpp.set_nvfp4_4over6(nvfp4_use_4over6); out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); @@ -1956,7 +1957,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; - const bool use_4over6 = this->use_4over6; + const bool nvfp4_use_4over6 = this->nvfp4_use_4over6; const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); @@ -2032,7 +2033,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); - kwargs["use_4over6"] = py::cast(use_4over6); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { @@ -2109,7 +2110,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; - const bool use_4over6 = this->use_4over6; + const bool nvfp4_use_4over6 = this->nvfp4_use_4over6; const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); @@ -2117,7 +2118,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( "Row-scaled NVFP4 quantization does not support columnwise usage."); } tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); - tensor.attr("_use_4over6") = py::cast(use_4over6); + tensor.attr("_nvfp4_use_4over6") = py::cast(nvfp4_use_4over6); tensor.attr("_nvfp4_e4m3_max") = py::cast(nvfp4_e4m3_max); // Coerce row-wise data @@ -2223,7 +2224,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); - out_cpp.set_nvfp4_4over6(use_4over6); + out_cpp.set_nvfp4_4over6(nvfp4_use_4over6); out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); @@ -2315,14 +2316,14 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); - quant_config.set_nvfp4_4over6(this->use_4over6); + quant_config.set_nvfp4_4over6(this->nvfp4_use_4over6); quant_config.set_nvfp4_e4m3_max(this->nvfp4_e4m3_max); quant_config.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); - quant_config_columnwise.set_nvfp4_4over6(this->use_4over6); + quant_config_columnwise.set_nvfp4_4over6(this->nvfp4_use_4over6); quant_config_columnwise.set_nvfp4_e4m3_max(this->nvfp4_e4m3_max); quant_config_columnwise.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); - if (this->use_4over6) { + if (this->nvfp4_use_4over6) { NVTE_CHECK(!this->with_rht, "NVFP4 4over6 quantization does not support RHT."); NVTE_CHECK(!this->stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); @@ -2470,7 +2471,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // NVFP4 4over6 candidate error math is controlled separately by // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math && !this->use_4over6) { + if (use_fast_math && !this->nvfp4_use_4over6) { quant_config.set_use_fast_math(true); quant_config_columnwise.set_use_fast_math(true); } diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index e25dacf76e..deb5183916 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -135,7 +135,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); - const bool use_4over6 = tensor.attr("_use_4over6").cast(); + const bool nvfp4_use_4over6 = tensor.attr("_nvfp4_use_4over6").cast(); const int nvfp4_e4m3_max = tensor.attr("_nvfp4_e4m3_max").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -167,7 +167,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); ret.set_row_scaled_nvfp4(row_scaled_nvfp4); - ret.set_nvfp4_4over6(use_4over6); + ret.set_nvfp4_4over6(nvfp4_use_4over6); ret.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Quantizer state diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 627b098b1a..42ce32065c 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -221,7 +221,7 @@ class NVFP4TensorRef(QuantizedTensorStorage): scale_t: Optional[torch.Tensor] = None global_amax_row: Optional[torch.Tensor] = None global_amax_col: Optional[torch.Tensor] = None - use_4over6: bool = False + nvfp4_use_4over6: bool = False nvfp4_e4m3_max: int = 448 dtype: Optional[torch.dtype] = None @@ -352,7 +352,7 @@ def __init__( eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", with_rht: bool = False, @@ -368,7 +368,7 @@ def __init__( raise ValueError( "Row-scaled NVFP4 reference quantization does not support columnwise usage." ) - if use_4over6: + if nvfp4_use_4over6: if pow_2_scales: raise ValueError("4over6 is only supported for NVFP4 (non-pow2) mode.") if quant_tile_shape not in ((1, 16), (16, 16)): @@ -383,8 +383,8 @@ def __init__( self.eps = eps self.quant_tile_shape = quant_tile_shape self.row_scaled_nvfp4 = row_scaled_nvfp4 - self.use_4over6 = use_4over6 - self.nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 + self.nvfp4_use_4over6 = nvfp4_use_4over6 + self.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 if self.nvfp4_e4m3_max not in (448, 256): raise ValueError("nvfp4_e4m3_max must be 448 or 256.") self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode @@ -583,7 +583,7 @@ def _quantize_blockwise_reference( *, pow_2_scales: bool, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", eps: float, # pylint: disable=unused-argument @@ -618,7 +618,7 @@ def _quantize_blockwise_reference( x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) - global_scale_e4m3_max = float(nvfp4_e4m3_max if use_4over6 else 448) + global_scale_e4m3_max = float(nvfp4_e4m3_max if nvfp4_use_4over6 else 448) GLOBAL_SCALE_E4M3_MAX = torch.tensor( global_scale_e4m3_max, device=x.device, dtype=torch.float32 ) @@ -653,7 +653,7 @@ def _quantize_blockwise_reference( global_encode_scale, ) global_decode_scale = torch.div(1.0, global_encode_scale) - if use_4over6: + if nvfp4_use_4over6: # FourOverSix compares map-to-4 and map-to-6 candidates using # the configured original input-domain error, while keeping TE-style FP4 # quantization for each candidate. @@ -829,7 +829,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, - use_4over6=self.use_4over6, + nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, eps=self.eps, @@ -855,7 +855,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, - use_4over6=self.use_4over6, + nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, eps=self.eps, @@ -897,7 +897,7 @@ def quantize( scale_t=sx_t, global_amax_row=global_amax_row, global_amax_col=global_amax_col, - use_4over6=self.use_4over6, + nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, dtype=tensor.dtype, device=tensor.device, @@ -946,7 +946,7 @@ def update_quantized( dst.scale_t = sx_t dst.global_amax_row = global_amax_row dst.global_amax_col = global_amax_col - dst.use_4over6 = self.use_4over6 + dst.nvfp4_use_4over6 = self.nvfp4_use_4over6 dst.nvfp4_e4m3_max = self.nvfp4_e4m3_max dst.dtype = src.dtype dst.quant_dtype = self.dtype @@ -1053,11 +1053,15 @@ def qgemm( sx = sx.to(torch.float32) sw = sw.to(torch.float32) - qresult_x_use_4over6 = getattr( - qresult_x, "use_4over6", getattr(qresult_x, "_use_4over6", self.use_4over6) + qresult_x_nvfp4_use_4over6 = getattr( + qresult_x, + "nvfp4_use_4over6", + getattr(qresult_x, "_nvfp4_use_4over6", self.nvfp4_use_4over6), ) - qresult_w_use_4over6 = getattr( - qresult_w, "use_4over6", getattr(qresult_w, "_use_4over6", self.use_4over6) + qresult_w_nvfp4_use_4over6 = getattr( + qresult_w, + "nvfp4_use_4over6", + getattr(qresult_w, "_nvfp4_use_4over6", self.nvfp4_use_4over6), ) qresult_x_e4m3_max = getattr( qresult_x, @@ -1069,11 +1073,11 @@ def qgemm( "nvfp4_e4m3_max", getattr(qresult_w, "_nvfp4_e4m3_max", self.nvfp4_e4m3_max), ) - if qresult_x_use_4over6: + if qresult_x_nvfp4_use_4over6: fp8_max_x = float(qresult_x_e4m3_max) else: fp8_max_x = 448.0 - if qresult_w_use_4over6: + if qresult_w_nvfp4_use_4over6: fp8_max_w = float(qresult_w_e4m3_max) else: fp8_max_w = 448.0 diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index b29119798c..5872c70963 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1655,15 +1655,15 @@ def _qparams(tensor_type: str): def _make(tensor_type: str) -> NVFP4Quantizer: qparams = _qparams(tensor_type) - use_4over6 = False + nvfp4_use_4over6 = False if self.recipe.nvfp4_4over6 == "all": - use_4over6 = True + nvfp4_use_4over6 = True elif self.recipe.nvfp4_4over6 == "weights": - use_4over6 = tensor_type == "weight" + nvfp4_use_4over6 = tensor_type == "weight" elif self.recipe.nvfp4_4over6 == "activations": - use_4over6 = tensor_type != "weight" + nvfp4_use_4over6 = tensor_type != "weight" nvfp4_e4m3_max = 448 - if use_4over6: + if nvfp4_use_4over6: if self.recipe.nvfp4_e4m3_max == "all": nvfp4_e4m3_max = 256 elif self.recipe.nvfp4_e4m3_max == "weights": @@ -1685,7 +1685,7 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and tensor_type != "weight" and self.recipe.row_scaled_activation ), - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.recipe.nvfp4_4over6_err_mode, ) diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index 51e89d8829..0cc03602a1 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -93,7 +93,7 @@ def __new__( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, ): if ( @@ -168,7 +168,7 @@ def __new__( columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, ) return instance @@ -202,7 +202,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 - dst.use_4over6 = src.use_4over6 + dst.nvfp4_use_4over6 = src.nvfp4_use_4over6 dst.nvfp4_e4m3_max = src.nvfp4_e4m3_max def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index afc9e102d5..eb38add62b 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -131,7 +131,7 @@ class NVFP4Quantizer(Quantizer): """Whether emitted NVFP4 tensors store one FP32 amax per row.""" row_scaled_nvfp4: bool """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection.""" - use_4over6: bool + nvfp4_use_4over6: bool """Global E4M3 scale bound used by emitted NVFP4 tensors.""" nvfp4_e4m3_max: int """NVFP4 4over6 candidate-selection error mode.""" @@ -153,7 +153,7 @@ def __init__( with_2d_quantization: bool = False, stochastic_rounding: bool = False, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, nvfp4_4over6_err_mode: str = "MAE", with_random_sign_mask: bool = True, @@ -167,8 +167,8 @@ def __init__( self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 - self.use_4over6 = use_4over6 - self.nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 + self.nvfp4_use_4over6 = nvfp4_use_4over6 + self.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 if self.nvfp4_e4m3_max not in (448, 256): raise ValueError("nvfp4_e4m3_max must be 448 or 256.") self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() @@ -220,7 +220,7 @@ def copy(self) -> NVFP4Quantizer: with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, row_scaled_nvfp4=self.row_scaled_nvfp4, - use_4over6=self.use_4over6, + nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, ) @@ -375,7 +375,7 @@ def __new__( quantizer: Quantizer, with_gemm_swizzled_scales: bool, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, **kwargs, ): @@ -392,7 +392,7 @@ def __new__( with_gemm_swizzled_scales, *args, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, **kwargs, ) @@ -552,7 +552,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m self._amax_rowwise, self._amax_columnwise, self._row_scaled_nvfp4, - self._use_4over6, + self._nvfp4_use_4over6, self._nvfp4_e4m3_max, self.shape[-1], ) @@ -578,7 +578,7 @@ def fsdp_post_all_gather( amax_rowwise, amax_columnwise, row_scaled_nvfp4, - use_4over6, + nvfp4_use_4over6, nvfp4_e4m3_max, K, ) = metadata @@ -605,7 +605,7 @@ def fsdp_post_all_gather( out._amax_rowwise = amax_rowwise out._amax_columnwise = amax_columnwise out._row_scaled_nvfp4 = row_scaled_nvfp4 - out._use_4over6 = use_4over6 + out._nvfp4_use_4over6 = nvfp4_use_4over6 out._nvfp4_e4m3_max = nvfp4_e4m3_max else: # Construct new tensor (first iteration) @@ -623,7 +623,7 @@ def fsdp_post_all_gather( requires_grad=False, with_gemm_swizzled_scales=False, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, ) @@ -764,7 +764,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, - use_4over6=tensor._use_4over6, + nvfp4_use_4over6=tensor._nvfp4_use_4over6, nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) @@ -786,7 +786,7 @@ def _make_in_reduce_ex( quantizer: Quantizer, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, ) -> NVFP4Tensor: """Build NVFP4Tensor, for use in __reduce__ @@ -809,7 +809,7 @@ def _make_in_reduce_ex( requires_grad=False, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, ) @@ -830,7 +830,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._quantizer, self._with_gemm_swizzled_scales, self._row_scaled_nvfp4, - self._use_4over6, + self._nvfp4_use_4over6, self._nvfp4_e4m3_max, ), ) @@ -885,7 +885,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: self._amax_columnwise = tensor._amax_columnwise self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales self._row_scaled_nvfp4 = tensor._row_scaled_nvfp4 - self._use_4over6 = tensor._use_4over6 + self._nvfp4_use_4over6 = tensor._nvfp4_use_4over6 self._nvfp4_e4m3_max = tensor._nvfp4_e4m3_max return @@ -1008,7 +1008,7 @@ def forward( requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, - use_4over6=tensor._use_4over6, + nvfp4_use_4over6=tensor._nvfp4_use_4over6, nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) @@ -1053,7 +1053,7 @@ def backward( requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, row_scaled_nvfp4=grad._row_scaled_nvfp4, - use_4over6=grad._use_4over6, + nvfp4_use_4over6=grad._nvfp4_use_4over6, nvfp4_e4m3_max=grad._nvfp4_e4m3_max, ) return dgrad, None @@ -1140,7 +1140,7 @@ def forward( requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, row_scaled_nvfp4=tensor._row_scaled_nvfp4, - use_4over6=tensor._use_4over6, + nvfp4_use_4over6=tensor._nvfp4_use_4over6, nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) @@ -1185,7 +1185,7 @@ def backward( requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, row_scaled_nvfp4=grad._row_scaled_nvfp4, - use_4over6=grad._use_4over6, + nvfp4_use_4over6=grad._nvfp4_use_4over6, nvfp4_e4m3_max=grad._nvfp4_e4m3_max, ) return dgrad, None diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 04e396be31..438e124021 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -73,7 +73,7 @@ def _initialize_storage_fields( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, ) -> None: """ @@ -151,8 +151,8 @@ def _initialize_storage_fields( instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance.row_scaled_nvfp4 = row_scaled_nvfp4 - instance.use_4over6 = use_4over6 - instance.nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 + instance.nvfp4_use_4over6 = nvfp4_use_4over6 + instance.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 def __new__( cls, @@ -179,7 +179,7 @@ def __new__( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, ): instance = object.__new__(cls) @@ -207,7 +207,7 @@ def __new__( stride=stride, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, ) return instance @@ -325,13 +325,13 @@ def row_scaled_nvfp4(self, row_scaled_nvfp4: bool) -> None: self._row_scaled_nvfp4 = row_scaled_nvfp4 @property - def use_4over6(self) -> bool: + def nvfp4_use_4over6(self) -> bool: """Whether grouped NVFP4 tensors carry 4over6 metadata.""" - return self._use_4over6 + return self._nvfp4_use_4over6 - @use_4over6.setter - def use_4over6(self, use_4over6: bool) -> None: - self._use_4over6 = use_4over6 + @nvfp4_use_4over6.setter + def nvfp4_use_4over6(self, nvfp4_use_4over6: bool) -> None: + self._nvfp4_use_4over6 = nvfp4_use_4over6 @property def nvfp4_e4m3_max(self) -> int: @@ -411,7 +411,7 @@ def clear(self) -> None: self.tensor_shapes = [] self.fake_dtype = torch.float32 self.row_scaled_nvfp4 = False - self.use_4over6 = False + self.nvfp4_use_4over6 = False self.nvfp4_e4m3_max = 448 def __repr__(self) -> str: @@ -582,7 +582,7 @@ def copy(self) -> "GroupedTensorStorage": columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, - use_4over6=self.use_4over6, + nvfp4_use_4over6=self.nvfp4_use_4over6, nvfp4_e4m3_max=self.nvfp4_e4m3_max, ) @@ -695,7 +695,7 @@ def make_grouped_tensor( scale_inv_offsets = None columnwise_scale_inv_offsets = None row_scaled_nvfp4 = False - use_4over6 = False + nvfp4_use_4over6 = False nvfp4_e4m3_max = 448 if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" @@ -756,7 +756,7 @@ def make_grouped_tensor( amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): row_scaled_nvfp4 = quantizer.row_scaled_nvfp4 - use_4over6 = quantizer.use_4over6 + nvfp4_use_4over6 = quantizer.nvfp4_use_4over6 nvfp4_e4m3_max = quantizer.nvfp4_e4m3_max if row_scaled_nvfp4: if not rowwise_usage: @@ -886,7 +886,7 @@ def make_grouped_tensor( quantizer.optimize_for_gemm if quantizer is not None else False ), row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() @@ -1002,7 +1002,7 @@ def split_into_quantized_tensors( self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets nvfp4_rowwise_amax_offsets = None row_scaled_nvfp4 = self.row_scaled_nvfp4 - use_4over6 = self.use_4over6 + nvfp4_use_4over6 = self.nvfp4_use_4over6 nvfp4_e4m3_max = self.nvfp4_e4m3_max if recipe.nvfp4() and row_scaled_nvfp4: cum = 0 @@ -1231,7 +1231,7 @@ def split_into_quantized_tensors( quantizer=quantizer, with_gemm_swizzled_scales=quantizer.optimize_for_gemm, row_scaled_nvfp4=row_scaled_nvfp4, - use_4over6=use_4over6, + nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max, ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index cc165d5c3e..7a7f6012dd 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -100,7 +100,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # Whether this NVFP4 tensor uses row-scaled amax metadata _row_scaled_nvfp4: bool # Whether this NVFP4 tensor uses 4over6 map-to-4/map-to-6 block selection - _use_4over6: bool + _nvfp4_use_4over6: bool # Global E4M3 scale bound used by this NVFP4 tensor _nvfp4_e4m3_max: int @@ -118,7 +118,7 @@ def __new__( *args, fake_dtype: Optional[torch.dtype] = None, row_scaled_nvfp4: bool = False, - use_4over6: bool = False, + nvfp4_use_4over6: bool = False, nvfp4_e4m3_max: int = 448, **kwargs, ): @@ -138,8 +138,8 @@ def __new__( instance._amax_columnwise = amax_columnwise instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance._row_scaled_nvfp4 = row_scaled_nvfp4 - instance._use_4over6 = use_4over6 - instance._nvfp4_e4m3_max = nvfp4_e4m3_max if use_4over6 else 448 + instance._nvfp4_use_4over6 = nvfp4_use_4over6 + instance._nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 return instance @@ -166,7 +166,7 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("Scale layout mismatch in copy_from_storage") if self._row_scaled_nvfp4 != src._row_scaled_nvfp4: raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") - if self._use_4over6 != src._use_4over6: + if self._nvfp4_use_4over6 != src._nvfp4_use_4over6: raise RuntimeError("NVFP4 4over6 mode mismatch in copy_from_storage") if self._nvfp4_e4m3_max != src._nvfp4_e4m3_max: raise RuntimeError("NVFP4 4over6 E4M3 scale bound mismatch in copy_from_storage") @@ -195,7 +195,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, "row_scaled_nvfp4": self._row_scaled_nvfp4, - "use_4over6": self._use_4over6, + "nvfp4_use_4over6": self._nvfp4_use_4over6, "nvfp4_e4m3_max": self._nvfp4_e4m3_max, "fake_dtype": self._dtype, } @@ -330,7 +330,7 @@ def view(self, shape: torch.Size): fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self._row_scaled_nvfp4, - use_4over6=self._use_4over6, + nvfp4_use_4over6=self._nvfp4_use_4over6, nvfp4_e4m3_max=self._nvfp4_e4m3_max, fake_dtype=self._dtype, ) From 83e230873f00676d4966ca151de22c8bfc68a77f Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 13 May 2026 02:36:40 -0700 Subject: [PATCH 57/57] Remove 4over6 benchmark Signed-off-by: Ziang Li --- benchmarks/benchmark_4over6.py | 221 --------------------------------- 1 file changed, 221 deletions(-) delete mode 100644 benchmarks/benchmark_4over6.py diff --git a/benchmarks/benchmark_4over6.py b/benchmarks/benchmark_4over6.py deleted file mode 100644 index af61417a03..0000000000 --- a/benchmarks/benchmark_4over6.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import argparse -import os - -import torch -import torch.utils.benchmark as benchmark -import transformer_engine.pytorch as te -import transformer_engine_torch as tex - -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer - - -BENCHMARK_SHAPES = [ - (8192, 5120), - (8192, 10240), - (8192, 2560), - (8192, 11328), - (8192, 512), - (8192, 3584), - (5120, 8192), - (10240, 8192), - (2560, 8192), - (11328, 8192), - (512, 8192), - (3584, 8192), - (4096, 16384), - (14336, 16384), -] -PROFILE_SHAPES = [ - (16384, 6144), -] -MIN_RUN_TIME = 5 - - -# Nsight Compute profiling command: -# ncu -f -o nvfp4_4over6 --set=full --profile-from-start off --target-processes all \ -# python3 benchmarks/benchmark_4over6.py --profile - - -def make_quantizer(use_2d_quantization, use_4over6, err_mode): - return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, - rowwise=True, - columnwise=True, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=use_2d_quantization, - stochastic_rounding=False, - row_scaled_nvfp4=False, - nvfp4_use_4over6=use_4over6, - nvfp4_e4m3_max=448, - nvfp4_4over6_err_mode=err_mode, - with_random_sign_mask=True, - ) - - -def set_err_fast_math(enabled): - os.environ["NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"] = "1" if enabled else "0" - - -def benchmark_quantize(shape, use_2d_quantization, use_4over6, err_mode, err_fast_math): - set_err_fast_math(err_fast_math) - - x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") - quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) - out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) - quantizer.update_quantized(x, out) - torch.cuda.synchronize() - - timing = benchmark.Timer( - stmt="quantizer.update_quantized(x, out)", - globals={"quantizer": quantizer, "x": x, "out": out}, - num_threads=1, - ).blocked_autorange(min_run_time=MIN_RUN_TIME) - return timing.median * 1e6 - - -def iter_profile_cases(): - for shape in PROFILE_SHAPES: - for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): - yield shape, mode_name, "nvfp4", "MAE", False, use_2d_quantization, False - - for err_mode in ("MAE", "MSE"): - for err_fast_math in (False, True): - yield ( - shape, - mode_name, - "4over6", - err_mode, - err_fast_math, - use_2d_quantization, - True, - ) - - -def prepare_profile_case(case): - shape, mode_name, kernel, err_mode, err_fast_math, use_2d_quantization, use_4over6 = case - set_err_fast_math(err_fast_math) - x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") - quantizer = make_quantizer(use_2d_quantization, use_4over6, err_mode) - out = quantizer.make_empty(shape, dtype=x.dtype, device=x.device, requires_grad=False) - quantizer.update_quantized(x, out) - torch.cuda.synchronize() - return { - "shape": shape, - "mode_name": mode_name, - "kernel": kernel, - "err_mode": err_mode, - "err_fast_math": err_fast_math, - "quantizer": quantizer, - "x": x, - "out": out, - } - - -def run_profile(profile_repeats): - cases = [prepare_profile_case(case) for case in iter_profile_cases()] - torch.cuda.synchronize() - torch.cuda.cudart().cudaProfilerStart() - for case in cases: - set_err_fast_math(case["err_fast_math"]) - print( - "PROFILE " - f"shape={case['shape']} mode={case['mode_name']} kernel={case['kernel']} " - f"err={case['err_mode']} err_fast={case['err_fast_math']}", - flush=True, - ) - for _ in range(profile_repeats): - case["quantizer"].update_quantized(case["x"], case["out"]) - torch.cuda.synchronize() - torch.cuda.cudart().cudaProfilerStop() - - -def run_benchmark(): - rows = [] - for shape in BENCHMARK_SHAPES: - for mode_name, use_2d_quantization in (("1d", False), ("2d", True)): - baseline_us = benchmark_quantize( - shape=shape, - use_2d_quantization=use_2d_quantization, - use_4over6=False, - err_mode="MAE", - err_fast_math=False, - ) - rows.append((shape, mode_name, "nvfp4", "-", baseline_us, 1.0, None, None)) - - for err_mode in ("MAE", "MSE"): - strict_timing_us = benchmark_quantize( - shape=shape, - use_2d_quantization=use_2d_quantization, - use_4over6=True, - err_mode=err_mode, - err_fast_math=False, - ) - fast_timing_us = benchmark_quantize( - shape=shape, - use_2d_quantization=use_2d_quantization, - use_4over6=True, - err_mode=err_mode, - err_fast_math=True, - ) - rows.append( - ( - shape, - mode_name, - "4over6", - err_mode, - strict_timing_us, - strict_timing_us / baseline_us, - fast_timing_us, - fast_timing_us / baseline_us, - ) - ) - - print( - f"{'shape':>18} {'mode':>4} {'kernel':>7} {'err':>3} " - f"{'strict_us':>10} {'strict':>8} {'fast_us':>10} {'fast':>8}" - ) - for ( - shape, - mode_name, - kernel, - err_mode, - strict_us, - strict_slowdown, - fast_us, - fast_slowdown, - ) in rows: - fast_us_str = "-" if fast_us is None else f"{fast_us:10.3f}" - fast_slowdown_str = "-" if fast_slowdown is None else f"{fast_slowdown:8.3f}x" - print( - f"{str(shape):>18} {mode_name:>4} {kernel:>7} {err_mode:>3} " - f"{strict_us:10.3f} {strict_slowdown:8.3f}x " - f"{fast_us_str:>10} {fast_slowdown_str:>8}" - ) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--profile", action="store_true", help="Enable Nsight Compute profile mode") - parser.add_argument( - "--profile-repeats", - default=1, - type=int, - help="Number of profiled update_quantized calls per case", - ) - args = parser.parse_args() - - if args.profile: - run_profile(args.profile_repeats) - else: - run_benchmark() - - -if __name__ == "__main__": - main()