diff --git a/docs/envvars.rst b/docs/envvars.rst
index ffbad409d4..2a9a23fd2f 100644
--- a/docs/envvars.rst
+++ b/docs/envvars.rst
@@ -287,6 +287,30 @@ Kernel Configuration
    :Default: ``0``
    :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar.
 
+.. envvar:: NVTE_NVFP4_4OVER6
+
+   :Type: ``str`` (``weights``, ``activations``, or ``all``)
+   :Default: unset
+   :Description: Enable per-block selection between map-to-4 and map-to-6 scale candidates for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. Each block uses the candidate with the lower error under the configured input-domain metric; ties select map-to-6. By default, this mode keeps the standard NVFP4 global E4M3 scale bound of 448. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled; activation and backward scopes therefore require ``NVTE_NVFP4_DISABLE_RHT=1`` and ``NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1``.
+
+.. envvar:: NVTE_NVFP4_4OVER6_E4M3_USE_256
+
+   :Type: ``str`` (``weights``, ``activations``, or ``all``)
+   :Default: unset
+   :Description: Use 256 instead of 448 as the global E4M3 scale bound for the selected NVFP4 4over6 quantizers. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. This option only takes effect for tensor roles that also enable :envvar:`NVTE_NVFP4_4OVER6`.
+
+.. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE
+
+   :Type: ``str`` (``MAE`` or ``MSE``)
+   :Default: ``MAE``
+   :Description: Select the input-domain error metric used for NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe.
+
+.. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH
+
+   :Type: ``int`` (0 or 1)
+   :Default: ``0``
+   :Description: Allow the NVFP4 4over6 candidate error computation to use faster, non-strict floating-point expressions. By default, the 4over6 error comparison uses strict expressions; note that ``NVTE_USE_FAST_MATH`` does not control this error-comparison path.
+
 Torch Compilation and Fusion
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
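The four variables above compose as follows. A minimal sketch of a driver script, using only the variable names documented here (everything else is illustrative); the settings must be in place before Transformer Engine reads its configuration:

```python
import os

# Enable 4over6 candidate selection for both weight and non-weight tensor roles.
os.environ["NVTE_NVFP4_4OVER6"] = "all"
# Optionally tighten the global E4M3 scale bound from 448 to 256 for weights only.
os.environ["NVTE_NVFP4_4OVER6_E4M3_USE_256"] = "weights"
# Pick the error metric used to choose between the map-to-4 and map-to-6 candidates.
os.environ["NVTE_NVFP4_4OVER6_ERR_MODE"] = "MSE"
# Per the description above, 4over6 currently requires RHT and stochastic
# rounding to be disabled for activation and backward scopes.
os.environ["NVTE_NVFP4_DISABLE_RHT"] = "1"
os.environ["NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING"] = "1"
```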
diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
index a8f58f8598..4825deb329 100644
--- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu
+++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
@@ -62,12 +62,14 @@ std::vector<InputType> create_transpose(const InputType* const input, const size
 }
 
 // Compute the global encode scale factor for a given global amax
-float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) {
-  constexpr float fp8_max = 448.0f;  // 448.0f;
+float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math,
+                                               const int e4m3_max = 448) {
+  NVTE_CHECK(e4m3_max == 448 || e4m3_max == 256, "Unsupported NVFP4 E4M3 max.");
+  const float fp8_max = static_cast<float>(e4m3_max);
   constexpr float fp4_max = 6.0f;  // 6.0f;
   float global_encode_scale = fp8_max * fp4_max / global_amax;
   // If scale is infinity, return the max normalized value
-  const float max_norm_clamp = use_fast_math
+  const float max_norm_clamp = (use_fast_math && e4m3_max == 448)
                                    ? Numeric_Traits<fp8e4m3>::maxNorm
                                    : Numeric_Traits<float>::maxNorm;
@@ -79,6 +81,108 @@ float compute_global_encode_scaling_factor_FP4(const float global_amax, const bo
   return global_encode_scale;
 }
 
+struct NVFP4FourOverSixQuantization {
+  fp8e4m3 scale_map4;
+  fp8e4m3 scale_map6;
+  float reciprocal_map4;
+  float reciprocal_map6;
+  fp4e2m1x2 quantized_map4;
+  fp4e2m1x2 quantized_map6;
+};
+
+enum class NVFP4FourOverSixCandidate {
+  Map4,
+  Map6,
+};
+
+enum class NVFP4ScalingMode {
+  Block1D,
+  RowScaled1D,
+  Block2D,
+};
+
+struct NVFP4FourOverSixTestConfig {
+  bool enabled = false;
+  int e4m3_max = 448;
+  NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE;
+  bool err_use_fast_math = false;
+};
+
+bool use_2d_quantization(const NVFP4ScalingMode scaling_mode) {
+  return scaling_mode == NVFP4ScalingMode::Block2D;
+}
+
+bool row_scaled_nvfp4(const NVFP4ScalingMode scaling_mode) {
+  return scaling_mode == NVFP4ScalingMode::RowScaled1D;
+}
+
+NVFP4FourOverSixQuantization compute_4over6_quantization_scales(
+    const float block_amax, const float global_encode_scale) {
+  constexpr float fp4_max = 6.0f;
+  constexpr float fp8_max = 448.0f;
+  constexpr float scale_expansion_factor = 1.5f;
+  const float base_sf_high_precision = block_amax / fp4_max * global_encode_scale;
+  const float sf_high_precision_map4 =
+      fminf(base_sf_high_precision * scale_expansion_factor, fp8_max);
+  const float sf_high_precision_map6 = fminf(base_sf_high_precision, fp8_max);
+  const fp8e4m3 scale_map4 = static_cast<fp8e4m3>(sf_high_precision_map4);
+  const fp8e4m3 scale_map6 = static_cast<fp8e4m3>(sf_high_precision_map6);
+
+  const float global_decode_scale = 1.0f / global_encode_scale;
+  const float scale_map4_fp32 = static_cast<float>(scale_map4);
+  const float reciprocal_map4 =
+      fminf(1.0f / (scale_map4_fp32 * global_decode_scale), Numeric_Traits<float>::maxNorm);
+  const float scale_map6_fp32 = static_cast<float>(scale_map6);
+  const float reciprocal_map6 =
+      fminf(1.0f / (scale_map6_fp32 * global_decode_scale), Numeric_Traits<float>::maxNorm);
+
+  const float2 zero = {0.0f, 0.0f};
+  return {
+      scale_map4,
+      scale_map6,
+      reciprocal_map4,
+      reciprocal_map6,
+      fp4e2m1x2(zero),
+      fp4e2m1x2(zero),
+  };
+}
+
+fp8e4m3 select_4over6_scale(const NVFP4FourOverSixQuantization& quantization,
+                            const NVFP4FourOverSixCandidate candidate) {
+  if (candidate == NVFP4FourOverSixCandidate::Map4) {
+    return quantization.scale_map4;
+  }
+  return quantization.scale_map6;
+}
+
+fp4e2m1x2 select_4over6_quantized_pair(const NVFP4FourOverSixQuantization& quantization,
+                                       const NVFP4FourOverSixCandidate candidate) {
+  if (candidate == NVFP4FourOverSixCandidate::Map4) {
+    return quantization.quantized_map4;
+  }
+  return quantization.quantized_map6;
+}
+
+NVFP4FourOverSixQuantization quantize_4over6_pair(
+    const float x, const float y, const NVFP4FourOverSixQuantization& quantization) {
+  const float2 scaled_map4 = {x * quantization.reciprocal_map4,
+                              y * quantization.reciprocal_map4};
+  const fp4e2m1x2 quantized_map4(scaled_map4);
+
+  const float2 scaled_map6 = {x * quantization.reciprocal_map6,
+                              y * quantization.reciprocal_map6};
+  const fp4e2m1x2 quantized_map6(scaled_map6);
+
+  return {
+      quantization.scale_map4,
+      quantization.scale_map6,
+      quantization.reciprocal_map4,
+      quantization.reciprocal_map6,
+      quantized_map4,
+      quantized_map6,
+  };
+}
+
 // 1D Scaling: Original implementation with 1x16 blocks
 template <typename InputType>
 void quantize_nvfp4_1d(float (*OP)(const float),
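The arithmetic in `compute_4over6_quantization_scales` is easier to see without the FP8 plumbing: the baseline decode scale maps the block amax onto ±6 of the FP4-E2M1 grid, and expanding that scale by 1.5x maps the same amax onto ±4 instead (hence "4over6"). A Python sketch under that reading, where `to_e4m3` is a hypothetical helper that rounds a float through FP8-E4M3 and back:

```python
def fourover6_candidates(block_amax, global_encode_scale, to_e4m3):
    """Mirror compute_4over6_quantization_scales above (sketch, not the test code).

    The 1.5x expansion makes the block amax encode to ~4 instead of 6:
    amax / (1.5 * amax / 6) == 4.
    """
    base = block_amax / 6.0 * global_encode_scale
    scale_map4 = to_e4m3(min(base * 1.5, 448.0))
    scale_map6 = to_e4m3(min(base, 448.0))
    decode = 1.0 / global_encode_scale
    recip_map4 = 1.0 / (scale_map4 * decode) if scale_map4 else 0.0
    recip_map6 = 1.0 / (scale_map6 * decode) if scale_map6 else 0.0
    return (scale_map4, recip_map4), (scale_map6, recip_map6)
```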
@@ -89,10 +193,15 @@ void quantize_nvfp4_1d(float (*OP)(const float),
                        const size_t cols,
                        const size_t scales_stride,
                        const float global_amax,
-                       const bool use_fast_math) {
+                       const bool use_fast_math,
+                       const bool use_4over6 = false,
+                       const int e4m3_max = 448,
+                       const NVFP4FourOverSixCandidate four_over_six_candidate =
+                           NVFP4FourOverSixCandidate::Map6) {
 
   // Compute a global encoding/decoding scaling factor for all S_dec_b
-  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
+  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math,
+                                                               e4m3_max);
 
   constexpr size_t block_size_X = 16;
   const size_t blocks_X = divide_round_up(cols, block_size_X);
@@ -122,6 +231,27 @@ void quantize_nvfp4_1d(float (*OP)(const float),
       block_amax = std::max(block_amax, std::abs(elt));
     }
 
+    const size_t scale_idx = i * scales_stride + block_X;
+
+    if (use_4over6) {
+      const NVFP4FourOverSixQuantization quantization =
+          compute_4over6_quantization_scales(block_amax, S_enc);
+      scales[scale_idx] = select_4over6_scale(quantization, four_over_six_candidate);
+
+      for (size_t j = j_min; j < j_max; j += 2) {
+        const int idx_pair = (i * cols + j) / 2;
+        const int cache_idx_x = j - j_min;
+        const int cache_idx_y = cache_idx_x + 1;
+        const float cached_x = cache_buffer[cache_idx_x];
+        const float cached_y = cache_buffer[cache_idx_y];
+        const NVFP4FourOverSixQuantization pair_quantization =
+            quantize_4over6_pair(cached_x, cached_y, quantization);
+        output[idx_pair] =
+            select_4over6_quantized_pair(pair_quantization, four_over_six_candidate);
+      }
+      continue;
+    }
+
     // Compute and store the per-block FP8 decode scale
     const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f));
     const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(fminf(S_dec_b, Numeric_Traits<fp8e4m3>::maxNorm));
@@ -131,7 +261,6 @@ void quantize_nvfp4_1d(float (*OP)(const float),
     const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f
                                   ? 0.f
                                   : fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
-    const size_t scale_idx = i * scales_stride + block_X;
     scales[scale_idx] = S_dec_b_fp8;
 
     float scale_reciprocal = S_enc_b_fp8;
@@ -167,9 +296,14 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
                                     const size_t cols,
                                     const float global_amax,
                                     std::vector<std::vector<fp8e4m3>>& math_scales,
-                                    const bool use_fast_math) {
-
-  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
+                                    const bool use_fast_math,
+                                    const bool use_4over6 = false,
+                                    const int e4m3_max = 448,
+                                    const NVFP4FourOverSixCandidate four_over_six_candidate =
+                                        NVFP4FourOverSixCandidate::Map6) {
+
+  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math,
+                                                               e4m3_max);
   constexpr size_t block_size_Y = 16;
   constexpr size_t block_size_X = 16;
   const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -197,9 +331,16 @@ void compute_2d_mathematical_scales(float (*OP)(const float),
       }
 
       // Compute E4M3 scaling factor for this 16x16 block
-      const float S_dec_b = block_amax / 6.0f;
-      const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
-      math_scales[block_Y][block_X] = S_dec_b_fp8;
+      if (use_4over6) {
+        const NVFP4FourOverSixQuantization quantization =
+            compute_4over6_quantization_scales(block_amax, S_enc);
+        math_scales[block_Y][block_X] =
+            select_4over6_scale(quantization, four_over_six_candidate);
+      } else {
+        const float S_dec_b = block_amax / 6.0f * S_enc;
+        const fp8e4m3 S_dec_b_fp8_map6 = static_cast<fp8e4m3>(S_dec_b);
+        math_scales[block_Y][block_X] = S_dec_b_fp8_map6;
+      }
     }
   }
 }
@@ -214,13 +355,19 @@ void quantize_nvfp4_2d(float (*OP)(const float),
                        const size_t cols,
                        const size_t scales_stride,
                        const float global_amax,
-                       const bool use_fast_math) {
+                       const bool use_fast_math,
+                       const bool use_4over6 = false,
+                       const int e4m3_max = 448,
+                       const NVFP4FourOverSixCandidate four_over_six_candidate =
+                           NVFP4FourOverSixCandidate::Map6) {
 
   // Step 1: Compute mathematical 8x8 scaling factors
   std::vector<std::vector<fp8e4m3>> math_scales;
-  compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
+  compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math,
+                                 use_4over6, e4m3_max, four_over_six_candidate);
 
-  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math);
+  const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math,
+                                                               e4m3_max);
   constexpr size_t block_size_Y = 16;
   constexpr size_t block_size_X = 16;
   const size_t blocks_Y = divide_round_up(rows, block_size_Y);
@@ -250,7 +397,7 @@ void quantize_nvfp4_2d(float (*OP)(const float),
 
       // Get the scaling factor for this block
       const float S_dec_b_fp8 = static_cast<float>(math_scales[block_Y][block_X]);
-      const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
+      const float S_enc_b_fp8 = S_dec_b_fp8 == 0.0f ? 0.0f : S_enc / S_dec_b_fp8;
       const float scale_reciprocal = S_enc_b_fp8;
 
       // Process and cache data for this 16x16 block
@@ -302,11 +449,17 @@ void quantize_nvfp4(float (*OP)(const float),
                     const size_t scales_stride,
                     const float global_amax,
                     const bool use_fast_math,
-                    const bool use_2d_quantization = false) {
+                    const bool use_2d_quantization = false,
+                    const bool use_4over6 = false,
+                    const int e4m3_max = 448,
+                    const NVFP4FourOverSixCandidate four_over_six_candidate =
+                        NVFP4FourOverSixCandidate::Map6) {
   if (use_2d_quantization) {
-    quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
+    quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax,
+                      use_fast_math, use_4over6, e4m3_max, four_over_six_candidate);
   } else {
-    quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math);
+    quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax,
+                      use_fast_math, use_4over6, e4m3_max, four_over_six_candidate);
   }
 }
 
@@ -324,7 +477,11 @@ void compute_ref(float (*OP)(const float),
                  const size_t scales_stride_t,
                  const bool use_fast_math,
                  const bool use_2d_quantization = false,
-                 const bool row_scaled_nvfp4 = false)
+                 const bool row_scaled_nvfp4 = false,
+                 const bool use_4over6 = false,
+                 const int e4m3_max = 448,
+                 const NVFP4FourOverSixCandidate four_over_six_candidate =
+                     NVFP4FourOverSixCandidate::Map6)
 {
   std::vector<InputType> input_t = create_transpose(input, rows, cols);
   NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4),
@@ -334,7 +491,8 @@ void compute_ref(float (*OP)(const float),
   if (use_2d_quantization) {
     // Step 1: Compute mathematical 8×8 scaling factors
     std::vector<std::vector<fp8e4m3>> math_scales;
-    compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math);
+    compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math,
+                                   use_4over6, e4m3_max, four_over_six_candidate);
 
     constexpr size_t block_size_Y = 16;
     constexpr size_t block_size_X = 16;
@@ -362,9 +520,11 @@ void compute_ref(float (*OP)(const float),
     // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
     // (This part processes the actual FP4 data using the mathematical scaling factors)
     quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, *amax,
-                      use_fast_math);  // scales already filled
+                      use_fast_math, use_4over6, e4m3_max,
+                      four_over_six_candidate);  // scales already filled
     quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, *amax,
-                      use_fast_math);  // scales_t already filled
+                      use_fast_math, use_4over6, e4m3_max,
+                      four_over_six_candidate);  // scales_t already filled
     return;
   }
 
@@ -381,16 +541,21 @@ void compute_ref(float (*OP)(const float),
                      scales_stride,
                      amax[row],
                      use_fast_math,
-                     use_2d_quantization);
+                     use_2d_quantization,
+                     use_4over6,
+                     e4m3_max,
+                     four_over_six_candidate);
     }
     return;
   }
 
   // Ref impl for basic NVFP4
   quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, *amax,
-                 use_fast_math, use_2d_quantization);
+                 use_fast_math, use_2d_quantization, use_4over6, e4m3_max,
+                 four_over_six_candidate);
   quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, *amax,
-                 use_fast_math, use_2d_quantization);
+                 use_fast_math, use_2d_quantization, use_4over6, e4m3_max,
+                 four_over_six_candidate);
 }
 
 void compare_nvfp4_tensors(const std::string& name,
@@ -515,6 +680,92 @@ void compareResults_nvfp4(Tensor &test,
   }
 }
 
+template <typename T>
+bool bitwise_equal(const T& x, const T& y) {
+  const auto *x_bytes = reinterpret_cast<const uint8_t*>(&x);
+  const auto *y_bytes = reinterpret_cast<const uint8_t*>(&y);
+  for (size_t i = 0; i < sizeof(T); ++i) {
+    if (x_bytes[i] != y_bytes[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool nvfp4_output_block_matches(const fp4e2m1x2* const test_data,
+                                const fp4e2m1x2* const ref_data,
+                                const size_t row,
+                                const size_t cols,
+                                const size_t block_x) {
+  constexpr size_t block_size_X = 16;
+  const size_t j_min = block_x * block_size_X;
+  const size_t j_max = std::min(j_min + block_size_X, cols);
+  for (size_t j = j_min; j < j_max; j += 2) {
+    const size_t idx_pair = (row * cols + j) / 2;
+    if (!bitwise_equal(test_data[idx_pair], ref_data[idx_pair])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void compare_nvfp4_4over6_candidates(const std::string& name,
+                                     const fp4e2m1* const test_data,
+                                     const fp8e4m3* const test_scales,
+                                     const fp4e2m1x2* const ref_data_map4,
+                                     const fp8e4m3* const ref_scales_map4,
+                                     const fp4e2m1x2* const ref_data_map6,
+                                     const fp8e4m3* const ref_scales_map6,
+                                     const size_t rows,
+                                     const size_t cols,
+                                     const size_t blocks_X,
+                                     const size_t scales_stride) {
+  constexpr int max_mismatches_to_print = 3;
+  const auto* const test_data_pairs = reinterpret_cast<const fp4e2m1x2*>(test_data);
+  size_t total_mismatches = 0;
+
+  for (size_t row = 0; row < rows; ++row) {
+    for (size_t block_x = 0; block_x < blocks_X; ++block_x) {
+      const size_t scale_idx = row * scales_stride + block_x;
+      const bool scale_matches_map4 =
+          bitwise_equal(test_scales[scale_idx], ref_scales_map4[scale_idx]);
+      const bool data_matches_map4 =
+          nvfp4_output_block_matches(test_data_pairs, ref_data_map4, row, cols, block_x);
+      const bool scale_matches_map6 =
+          bitwise_equal(test_scales[scale_idx], ref_scales_map6[scale_idx]);
+      const bool data_matches_map6 =
+          nvfp4_output_block_matches(test_data_pairs, ref_data_map6, row, cols, block_x);
+
+      if ((scale_matches_map4 && data_matches_map4) ||
+          (scale_matches_map6 && data_matches_map6)) {
+        continue;
+      }
+
+      ++total_mismatches;
+      if (total_mismatches <= max_mismatches_to_print) {
+        std::cout << "Error in tensor " << name << ": 4over6 block mismatch at row "
+                  << row << ", block_x " << block_x
+                  << ". The output did not match either map-to-4 or map-to-6 exactly."
+                  << std::endl;
+      }
+    }
+  }
+
+  std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
+  std::cout << "Total 4over6 blocks checked: " << (rows * blocks_X) << std::endl;
+  if (total_mismatches > 0) {
+    std::cout << "STATUS: FAILED for output" << std::endl;
+    std::cout << "Total mismatched 4over6 blocks found: " << total_mismatches << std::endl;
+    std::cout << "============================" << std::endl;
+    GTEST_FAIL() << "Found " << total_mismatches << " 4over6 block mismatches in tensor "
+                 << name;
+  }
+
+  std::cout << "STATUS: PASSED for output" << std::endl;
+  std::cout << "Each 4over6 block matched either map-to-4 or map-to-6 exactly" << std::endl;
+  std::cout << "============================" << std::endl;
+}
+
 void compare_rowwise_amax(Tensor &output, const std::vector<float> &ref_amax) {
   ASSERT_EQ(output.rowwise_amax_size(), ref_amax.size());
   const auto *amax_ptr = output.cpu_rowwise_amax_ptr();
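The comparator above accepts a block if it matches either candidate because the CPU reference does not reproduce the kernel's error comparison bit-for-bit; the kernel itself picks per block using the configured metric. Roughly, the selection behaves as below (illustrative Python pseudocode grounded in the envvar documentation, not the CUDA implementation):

```python
def select_candidate(block_vals, dequant_map4, dequant_map6, err_mode="MAE"):
    """Pick map-to-4 vs map-to-6 per block; ties go to map-to-6 (sketch).

    `dequant_map4`/`dequant_map6` are hypothetical callables that quantize
    and dequantize one value under the respective candidate scale.
    """
    def err(dequant):
        diffs = [abs(dequant(v) - v) for v in block_vals]
        if err_mode == "MSE":
            return sum(d * d for d in diffs) / len(diffs)
        return sum(diffs) / len(diffs)  # MAE (default)

    e4, e6 = err(dequant_map4), err(dequant_map6)
    return "map4" if e4 < e6 else "map6"  # equal errors select map-to-6
```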
@@ -529,14 +780,27 @@ template <typename InputType>
 void performTest(float (*OP)(const float),
                  const std::vector<size_t>& shape,
                  const bool use_fast_math,
-                 const bool row_scaled_nvfp4 = false) {
+                 const NVFP4ScalingMode scaling_mode = NVFP4ScalingMode::Block1D,
+                 const bool use_4over6 = false,
+                 const int e4m3_max = 448,
+                 const NVTENVFP44Over6ErrMode err_mode = kNVTENVFP44Over6ErrMAE,
+                 const bool use_4over6_err_use_fast_math = false) {
   using namespace test;
 
+  if (use_4over6 && use_fast_math) {
+    std::cout << "WARNING: Plain NVFP4 fast math is ignored for 4over6. "
+                 "Use use_4over6_err_use_fast_math to test the 4over6 candidate "
+                 "error fast-math path."
+              << std::endl;
+  }
+
   DType itype = TypeInfo<InputType>::dtype;
   DType otype = DType::kFloat4E2M1;
 
+  const bool is_2d_quantization = use_2d_quantization(scaling_mode);
+  const bool is_row_scaled_nvfp4 = row_scaled_nvfp4(scaling_mode);
   const bool rowwise = true;
-  const bool columnwise = !row_scaled_nvfp4;
+  const bool columnwise = !is_row_scaled_nvfp4;
 
   const size_t rows = first_dimension(shape);
   const size_t cols = last_dimension(shape);
@@ -560,17 +824,56 @@ void performTest(float (*OP)(const float),
 
   Tensor input("input", shape, itype);
   Tensor output("output", shape, otype, rowwise, columnwise, NVTE_NVFP4_1D_SCALING);
+  output.set_nvfp4_4over6(use_4over6);
+  output.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
 
   std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
   std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
   std::unique_ptr<fp8e4m3[]> ref_scales = std::make_unique<fp8e4m3[]>(blocks_Y * blocks_X);
   std::unique_ptr<fp8e4m3[]> ref_scales_t = std::make_unique<fp8e4m3[]>(blocks_Y_t * blocks_X_t);
+  std::unique_ptr<fp4e2m1x2[]> ref_output_map6;
+  std::unique_ptr<fp4e2m1x2[]> ref_output_t_map6;
+  std::unique_ptr<fp8e4m3[]> ref_scales_map6;
+  std::unique_ptr<fp8e4m3[]> ref_scales_t_map6;
 
   fillCase(&input, InputsFillCase::uniform);
 
+  if (use_4over6 && is_row_scaled_nvfp4) {
+    const float target_row_amax = static_cast<float>(e4m3_max) * 6.0f * 8.0f;
+    auto *input_vals = input.rowwise_cpu_dptr<InputType>();
+    for (size_t row = 0; row < rows; ++row) {
+      float row_amax = 0.0f;
+      size_t max_col = 0;
+      for (size_t col = 0; col < cols; ++col) {
+        const float val = static_cast<float>(input_vals[row * cols + col]);
+        const float abs_val = fabsf(val);
+        if (abs_val > row_amax) {
+          row_amax = abs_val;
+          max_col = col;
+        }
+      }
+
+      if (row_amax == 0.0f) {
+        continue;
+      }
+
+      const float row_scale = target_row_amax / row_amax;
+      for (size_t col = 0; col < cols; ++col) {
+        float scaled = static_cast<float>(input_vals[row * cols + col]) * row_scale;
+        scaled = fminf(fmaxf(scaled, -target_row_amax), target_row_amax);
+        input_vals[row * cols + col] = static_cast<InputType>(scaled);
+      }
+
+      const float max_val = static_cast<float>(input_vals[row * cols + max_col]);
+      input_vals[row * cols + max_col] =
+          static_cast<InputType>(max_val < 0.0f ? -target_row_amax : target_row_amax);
+    }
+    input.from_cpu();
+  }
+
   // Compute 2nd stage NVFP4 scaling factor
   std::vector<float> ref_amax;
-  if (row_scaled_nvfp4) {
+  if (is_row_scaled_nvfp4) {
     // Compute per-row amaxes
     const auto *input_vals = input.rowwise_cpu_dptr<InputType>();
     for (size_t row = 0; row < rows; ++row){
@@ -584,10 +887,14 @@ void performTest(float (*OP)(const float),
     // Update tensor
     // Note: No need to update amax like standard NVFP4, amaxes
     // are computed during quantization.
-    output.set_row_scaled_nvfp4(row_scaled_nvfp4);
+    output.set_row_scaled_nvfp4(is_row_scaled_nvfp4);
   } else {
     // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues
-    ref_amax.assign(1, 448.0f * 6.0f * 8.0f);
+    if (use_4over6) {
+      ref_amax.assign(1, static_cast<float>(e4m3_max) * 6.0f * 8.0f);
+    } else {
+      ref_amax.assign(1, 448.0f * 6.0f * 8.0f);
+    }
 
     // Update tensor
     if (rowwise) {
@@ -599,22 +906,63 @@ void performTest(float (*OP)(const float),
     output.from_cpu();
   }
 
-  // Reference implementation
-  bool use_2d_quantization = false;
-  compute_ref(OP,
-              input.rowwise_cpu_dptr<InputType>(),
-              ref_output.get(),
-              ref_output_t.get(),
-              ref_scales.get(),
-              ref_scales_t.get(),
-              ref_amax.data(),
-              rows,
-              cols,
-              scales_stride,
-              scales_stride_t,
-              use_fast_math,
-              use_2d_quantization,
-              row_scaled_nvfp4);
+  if (use_4over6) {
+    ref_output_map6 = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
+    ref_output_t_map6 = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
+    ref_scales_map6 = std::make_unique<fp8e4m3[]>(blocks_Y * blocks_X);
+    ref_scales_t_map6 = std::make_unique<fp8e4m3[]>(blocks_Y_t * blocks_X_t);
+
+    compute_ref(OP,
+                input.rowwise_cpu_dptr<InputType>(),
+                ref_output.get(),
+                ref_output_t.get(),
+                ref_scales.get(),
+                ref_scales_t.get(),
+                ref_amax.data(),
+                rows,
+                cols,
+                scales_stride,
+                scales_stride_t,
+                use_fast_math,
+                is_2d_quantization,
+                is_row_scaled_nvfp4,
+                use_4over6,
+                e4m3_max,
+                NVFP4FourOverSixCandidate::Map4);
+    compute_ref(OP,
+                input.rowwise_cpu_dptr<InputType>(),
+                ref_output_map6.get(),
+                ref_output_t_map6.get(),
+                ref_scales_map6.get(),
+                ref_scales_t_map6.get(),
+                ref_amax.data(),
+                rows,
+                cols,
+                scales_stride,
+                scales_stride_t,
+                use_fast_math,
+                is_2d_quantization,
+                is_row_scaled_nvfp4,
+                use_4over6,
+                e4m3_max,
+                NVFP4FourOverSixCandidate::Map6);
+  } else {
+    compute_ref(OP,
+                input.rowwise_cpu_dptr<InputType>(),
+                ref_output.get(),
+                ref_output_t.get(),
+                ref_scales.get(),
+                ref_scales_t.get(),
+                ref_amax.data(),
+                rows,
+                cols,
+                scales_stride,
+                scales_stride_t,
+                use_fast_math,
+                is_2d_quantization,
+                is_row_scaled_nvfp4,
+                use_4over6);
+  }
 
   // Initialize stochastic rounding
   Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
@@ -624,10 +972,14 @@ void performTest(float (*OP)(const float),
 
   // Quantization options
   QuantizationConfigWrapper quant_config;
-  quant_config.set_use_fast_math(use_fast_math);
+  quant_config.set_use_fast_math(use_fast_math && !use_4over6);
   quant_config.set_stochastic_rounding(false);
   quant_config.set_rng_state(rng_state.data());
-  quant_config.set_nvfp4_2d_quantization(use_2d_quantization);
+  quant_config.set_nvfp4_2d_quantization(is_2d_quantization);
+  quant_config.set_nvfp4_4over6(use_4over6);
+  quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
+  quant_config.set_nvfp4_4over6_err_mode(err_mode);
+  quant_config.set_nvfp4_4over6_err_use_fast_math(use_4over6 && use_4over6_err_use_fast_math);
 
   // Call appropriate function based on operation type
   // Activation functions take 3 parameters (input, output, stream)
@@ -656,21 +1008,50 @@ void performTest(float (*OP)(const float),
   const double atol = 1.0E-6;
   const double rtol = 1.0E-6;
 
-  // Set dump_data=true to enable dumping tensor data to files for analysis
-  compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true,
-                       false, !row_scaled_nvfp4);
-
-  size_t scale_mismatches_num = 0;
-  compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
-                          ref_scales.get(),
-                          unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
-                          scale_mismatches_num);
-
-  if (!row_scaled_nvfp4) {
-    compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
-                            ref_scales_t.get(),
-                            unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
+  if (use_4over6) {
+    output.to_cpu();
+    compare_nvfp4_4over6_candidates("output",
+                                    output.rowwise_cpu_dptr<fp4e2m1>(),
+                                    output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
+                                    ref_output.get(),
+                                    ref_scales.get(),
+                                    ref_output_map6.get(),
+                                    ref_scales_map6.get(),
+                                    rows,
+                                    cols,
+                                    unpadded_blocks_X,
+                                    scales_stride);
+    if (!is_row_scaled_nvfp4) {
+      compare_nvfp4_4over6_candidates("output_t",
+                                      output.columnwise_cpu_dptr<fp4e2m1>(),
+                                      output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
+                                      ref_output_t.get(),
+                                      ref_scales_t.get(),
+                                      ref_output_t_map6.get(),
+                                      ref_scales_t_map6.get(),
+                                      cols,
+                                      rows,
+                                      unpadded_blocks_X_t,
+                                      scales_stride_t);
+    }
+  } else {
+    // Set dump_data=true to enable dumping tensor data to files for analysis
+    compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol,
+                         true, false, !is_row_scaled_nvfp4);
+
+    size_t scale_mismatches_num = 0;
+    compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
+                            ref_scales.get(),
+                            unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
                             scale_mismatches_num);
+
+    if (!is_row_scaled_nvfp4) {
+      compare_scaling_factors("scales_t",
+                              output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
+                              ref_scales_t.get(),
+                              unpadded_blocks_Y_t, unpadded_blocks_X_t,
+                              scales_stride_t,
+                              scale_mismatches_num);
+    }
   }
 
   compare_rowwise_amax(output, ref_amax);
@@ -707,7 +1088,8 @@
 class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam<std::tuple<ActivationType,
                                                                                     std::vector<size_t>,
                                                                                     transformer_engine::DType,
                                                                                     bool,
-                                                                                    bool>> {};
+                                                                                    NVFP4ScalingMode,
+                                                                                    NVFP4FourOverSixTestConfig>> {};
 
 TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   // Skip tests for pre-Blackwell architectures
@@ -722,7 +1104,8 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   const auto tensor_dims = std::get<1>(GetParam());
   const DType input_type = std::get<2>(GetParam());
   const bool use_fast_math = std::get<3>(GetParam());
-  const bool row_scaled_nvfp4 = std::get<4>(GetParam());
+  const NVFP4ScalingMode scaling_mode = std::get<4>(GetParam());
+  const NVFP4FourOverSixTestConfig config = std::get<5>(GetParam());
 
   // Skip tests if the input tensor is 1D
   if (tensor_dims.size() < 2) {
@@ -740,7 +1123,9 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   }
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
-    performTest<InputType>(OP, tensor_dims, use_fast_math, row_scaled_nvfp4);
+    performTest<InputType>(OP, tensor_dims, use_fast_math, scaling_mode, config.enabled,
+                           config.e4m3_max, config.err_mode,
+                           config.err_use_fast_math);
   );
 }
 
@@ -756,49 +1141,94 @@ std::string to_string(const ActivationType Act_type) {
   }
 }
 
+std::string to_string(const NVFP4ScalingMode scaling_mode) {
+  switch (scaling_mode) {
+    case NVFP4ScalingMode::Block1D: return "";
+    case NVFP4ScalingMode::RowScaled1D: return "XROW_SCALED";
+    case NVFP4ScalingMode::Block2D: return "X2D";
+    default: return "";
+  }
+}
+
+std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) {
+  std::string name = to_string(std::get<0>(param));
+  const auto& shape = std::get<1>(param);
+  for (const auto& s: shape) {
+    name += "X" + std::to_string(s);
+  }
+  name += "X" + test::typeName(std::get<2>(param));
+  if (std::get<3>(param)) {
+    name += "X_FAST_SCALING";
+  }
+  name += to_string(std::get<4>(param));
+  const NVFP4FourOverSixTestConfig& config = std::get<5>(param);
+  if (config.enabled) {
+    name += "X4OVER6";
+    if (config.e4m3_max == 448) {
+      name += "XE4M3_MAX_448";
+    } else {
+      name += "XE4M3_MAX_256";
+    }
+    if (config.err_mode == kNVTENVFP44Over6ErrMSE) {
+      name += "XMSE";
+    } else {
+      name += "XMAE";
+    }
+    if (config.err_use_fast_math) {
+      name += "XERR_USE_FAST_MATH";
+    }
+  }
+  return name;
+}
+
 INSTANTIATE_TEST_SUITE_P(
   OperatorTest,
   FusedCastTransposeNVFP4TestSuite,
   ::testing::Combine(
-      ::testing::ValuesIn(Activation_types),
-      ::testing::ValuesIn(tensor_dims),
-      ::testing::Values(DType::kBFloat16),
-      ::testing::Values(false),
-      ::testing::Values(false)),
+      ::testing::ValuesIn(Activation_types),                 // activation_type
+      ::testing::ValuesIn(tensor_dims),                      // tensor_dims
+      ::testing::Values(DType::kBFloat16),                   // input_type
+      ::testing::Values(false),                              // use_fast_math
+      ::testing::Values(NVFP4ScalingMode::Block1D),          // scaling_mode
+      ::testing::Values(NVFP4FourOverSixTestConfig{})),      // four_over_six_config
   [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
-    std::string name = to_string(std::get<0>(info.param));
-    const auto& shape = std::get<1>(info.param);
-    for ( const auto& s: shape) {
-      name += "X" + std::to_string(s);
-    }
-    name += "X" + test::typeName(std::get<2>(info.param));
-    if (std::get<3>(info.param)) {
-      name += "X_FAST_SCALING";
-    }
-    return name;
+    return test_name(info.param);
   });
 
 INSTANTIATE_TEST_SUITE_P(
   OperatorTestRowScaled,
   FusedCastTransposeNVFP4TestSuite,
   ::testing::Combine(
-      ::testing::Values(ActivationType::Identity),
-      ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]),
-      ::testing::Values(DType::kBFloat16, DType::kFloat32),
-      ::testing::Values(false),
-      ::testing::Values(true)),
+      ::testing::ValuesIn(Activation_types),                 // activation_type
+      ::testing::ValuesIn(tensor_dims),                      // tensor_dims
+      ::testing::Values(DType::kBFloat16, DType::kFloat32),  // input_type
+      ::testing::Values(false),                              // use_fast_math
+      ::testing::Values(NVFP4ScalingMode::RowScaled1D),      // scaling_mode
+      ::testing::Values(NVFP4FourOverSixTestConfig{})),      // four_over_six_config
   [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
-    std::string name = to_string(std::get<0>(info.param));
-    const auto& shape = std::get<1>(info.param);
-    for (const auto& s: shape) {
-      name += "X" + std::to_string(s);
-    }
-    name += "X" + test::typeName(std::get<2>(info.param));
-    if (std::get<3>(info.param)) {
-      name += "X_FAST_SCALING";
-    }
-    if (std::get<4>(info.param)) {
-      name += "XROW_SCALED";
-    }
-    return name;
+    return test_name(info.param);
+  });
+
+INSTANTIATE_TEST_SUITE_P(
+  OperatorTest4Over6,
+  FusedCastTransposeNVFP4TestSuite,
+  ::testing::Combine(
+      ::testing::ValuesIn(Activation_types),                 // activation_type
+      ::testing::ValuesIn(tensor_dims),                      // tensor_dims
+      ::testing::Values(DType::kBFloat16, DType::kFloat32),  // input_type
+      ::testing::Values(false),                              // use_fast_math
+      ::testing::Values(NVFP4ScalingMode::Block1D,
                        NVFP4ScalingMode::RowScaled1D,
                        NVFP4ScalingMode::Block2D),           // scaling_mode
+      ::testing::Values(
+          NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMAE, false},
+          NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMAE, true},
+          NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMSE, false},
+          NVFP4FourOverSixTestConfig{true, 448, kNVTENVFP44Over6ErrMSE, true},
+          NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMAE, false},
+          NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMAE, true},
+          NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMSE, false},
+          NVFP4FourOverSixTestConfig{true, 256, kNVTENVFP44Over6ErrMSE, true})),  // four_over_six_config
+  [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
+    return test_name(info.param);
  });
diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu
index eb9e8bce23..0dfaca7e58 100644
--- a/tests/cpp/operator/test_dequantize_nvfp4.cu
+++ b/tests/cpp/operator/test_dequantize_nvfp4.cu
@@ -46,8 +46,9 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
                                   OType *output,
                                   size_t rows,
                                   size_t cols,
-                                  size_t scale_stride) {
-  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
+                                  size_t scale_stride,
+                                  int e4m3_max) {
+  const float factor_inv = 1.0f / (6.0f * static_cast<float>(e4m3_max));
   constexpr size_t BLOCK_SIZE = 16;
   const size_t Mread = cols / BLOCK_SIZE;
   const size_t bytes_per_block = BLOCK_SIZE / 2;
@@ -90,7 +91,9 @@ float compute_amax(test::Tensor &t, size_t rows, size_t cols) {
 // against a CPU reference computed from the quantized data.
 template <typename OType>
 void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
-                                  const bool row_scaled_nvfp4) {
+                                  const bool row_scaled_nvfp4,
+                                  const bool use_4over6,
+                                  const int e4m3_max) {
   using namespace test;
 
   DType otype = TypeInfo<OType>::dtype;
@@ -105,6 +108,10 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
 
   // Configure quantized tensor amax
   size_t amax_size = 1;
+  quantized.set_nvfp4_4over6(use_4over6);
+  quantized.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
+  ASSERT_EQ(quantized.nvfp4_4over6(), use_4over6);
+  ASSERT_EQ(quantized.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448));
   if (row_scaled_nvfp4) {
     quantized.set_row_scaled_nvfp4(true);
     amax_size = rows;
@@ -116,7 +123,10 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
 
   // Quantize
   if (rows > 0 && cols > 0) {
-    nvte_quantize(input.data(), quantized.data(), 0);
+    QuantizationConfigWrapper quant_config;
+    quant_config.set_nvfp4_4over6(use_4over6);
+    quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
+    nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0);
     cudaDeviceSynchronize();
     auto err = cudaGetLastError();
     ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
@@ -146,7 +156,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
       std::make_unique<OType[]>(rows * cols);
   compute_ref_dequantize_nvfp4(
       fp4_data, scales, amax_vals, ref_output.get(),
-      rows, cols, scale_stride);
+      rows, cols, scale_stride, (use_4over6 ? e4m3_max : 448));
 
   // Compare results from TE and reference impls
   auto [atol, rtol] = getTolerances(otype);
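For orientation, the reference dequantization above amounts to the following per-block math, now parameterized by `e4m3_max`. A Python sketch, where `decode_fp4` and `decode_e4m3` are hypothetical nibble/byte decoders and the nibble order within a byte is illustrative:

```python
def dequantize_block(packed_bytes, block_scale_e4m3, amax, e4m3_max,
                     decode_fp4, decode_e4m3):
    """Sketch of the NVFP4 block dequant math used by the CPU reference.

    out = fp4_value * block_scale * amax / (6 * e4m3_max), where the last
    factor is the global decode scale (the inverse of the encode scale
    e4m3_max * 6 / amax).
    """
    factor_inv = 1.0 / (6.0 * e4m3_max)  # mirrors `factor_inv` in the diff
    out = []
    for byte in packed_bytes:            # two FP4-E2M1 values per byte
        for nibble in (byte & 0xF, byte >> 4):
            out.append(decode_fp4(nibble) * decode_e4m3(block_scale_e4m3)
                       * amax * factor_inv)
    return out
```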
@@ -156,7 +166,9 @@
 // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path.
 template <typename OType>
 void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
-                                           const bool row_scaled_nvfp4) {
+                                           const bool row_scaled_nvfp4,
+                                           const bool use_4over6,
+                                           const int e4m3_max) {
   using namespace test;
 
   DType otype = TypeInfo<OType>::dtype;
@@ -165,6 +177,10 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
 
   Tensor quantized_compact("quantized_compact", std::vector<size_t>{rows, cols}, DType::kFloat4E2M1,
                            true, false, NVTE_NVFP4_1D_SCALING);
+  quantized_compact.set_nvfp4_4over6(use_4over6);
+  quantized_compact.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
+  ASSERT_EQ(quantized_compact.nvfp4_4over6(), use_4over6);
+  ASSERT_EQ(quantized_compact.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448));
   if (row_scaled_nvfp4) {
     quantized_compact.set_row_scaled_nvfp4(true);
   } else if (rows > 0 && cols > 0) {
@@ -174,7 +190,10 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
   }
 
   if (rows > 0 && cols > 0) {
-    nvte_quantize(input.data(), quantized_compact.data(), 0);
+    QuantizationConfigWrapper quant_config;
+    quant_config.set_nvfp4_4over6(use_4over6);
+    quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
+    nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0);
     cudaDeviceSynchronize();
   }
 
@@ -186,6 +205,10 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
   // Create tensor with same FP4 data but swizzled scales
   Tensor quantized_swizzled("quantized_swizzled", std::vector<size_t>{rows, cols}, DType::kFloat4E2M1,
                             true, false, NVTE_NVFP4_1D_SCALING);
+  quantized_swizzled.set_nvfp4_4over6(use_4over6);
+  quantized_swizzled.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
+  ASSERT_EQ(quantized_swizzled.nvfp4_4over6(), use_4over6);
+  ASSERT_EQ(quantized_swizzled.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448));
   if (row_scaled_nvfp4) {
     quantized_swizzled.set_row_scaled_nvfp4(true);
   } else {
@@ -260,7 +283,9 @@ std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
 
 class DequantizeNVFP4TestSuite : public ::testing::TestWithParam<std::tuple<std::pair<size_t, size_t>,
                                                                             transformer_engine::DType,
-                                                                            bool>> {};
+                                                                            bool,
+                                                                            bool,
+                                                                            int>> {};
 
 TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
 {
@@ -271,10 +296,12 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
   const bool row_scaled_nvfp4 = std::get<2>(GetParam());
+  const bool use_4over6 = std::get<3>(GetParam());
+  const int e4m3_max = use_4over6 ? std::get<4>(GetParam()) : 448;
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4<OutputType>(
-        tensor_size.first, tensor_size.second, row_scaled_nvfp4);
+        tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, e4m3_max);
   );
 }
 
@@ -284,13 +311,20 @@ INSTANTIATE_TEST_SUITE_P(
   ::testing::Combine(
       ::testing::ValuesIn(nvfp4_tensor_dims),
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
-      ::testing::Bool()),
+      ::testing::Bool(),
+      ::testing::Bool(),
+      ::testing::Values(448, 256)),
  [](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info) {
    std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                       std::to_string(std::get<0>(info.param).second) + "X" +
                       test::typeName(std::get<1>(info.param)) + "X" +
-                       (std::get<2>(info.param) ? "RowScaled" : "PerTensor");
+                       (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
+                       (std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" +
+                       (std::get<3>(info.param)
+                            ? (std::get<4>(info.param) == 256 ? "E4M3Max256" : "E4M3Max448")
+                            : (std::get<4>(info.param) == 256 ? "E4M3Max256Ignored"
+                                                              : "E4M3Max448"));
    return name;
  }
 );
@@ -298,7 +332,9 @@ INSTANTIATE_TEST_SUITE_P(
 
 class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<size_t, size_t>,
                                                                                     transformer_engine::DType,
-                                                                                    bool>> {};
+                                                                                    bool,
+                                                                                    bool,
+                                                                                    int>> {};
 
 TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
 {
@@ -309,10 +345,12 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
   const bool row_scaled_nvfp4 = std::get<2>(GetParam());
+  const bool use_4over6 = std::get<3>(GetParam());
+  const int e4m3_max = use_4over6 ? std::get<4>(GetParam()) : 448;
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4_swizzled<OutputType>(
-        tensor_size.first, tensor_size.second, row_scaled_nvfp4);
+        tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, e4m3_max);
   );
 }
 
@@ -322,13 +360,20 @@ INSTANTIATE_TEST_SUITE_P(
   ::testing::Combine(
      ::testing::ValuesIn(nvfp4_tensor_dims),
      ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
-     ::testing::Bool()),
+     ::testing::Bool(),
+     ::testing::Bool(),
+     ::testing::Values(448, 256)),
  [](const testing::TestParamInfo<DequantizeNVFP4SwizzledTestSuite::ParamType>& info) {
    std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                       std::to_string(std::get<0>(info.param).second) + "X" +
                       test::typeName(std::get<1>(info.param)) + "X" +
                       (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
+                       (std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" +
+                       (std::get<3>(info.param)
+                            ? (std::get<4>(info.param) == 256 ? "E4M3Max256" : "E4M3Max448")
+                            : (std::get<4>(info.param) == 256 ? "E4M3Max256Ignored"
+                                                              : "E4M3Max448")) + "X" +
                       "Swizzled";
    return name;
  }
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 4fd75bb927..3569ab258a 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -440,6 +440,30 @@ void Tensor::set_row_scaled_nvfp4(bool row_scaled_nvfp4) {
   }
 }
 
+void Tensor::set_nvfp4_4over6(bool nvfp4_4over6) {
+  NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
+             "NVFP4 4over6 is only supported for NVFP4 tensors.");
+  tensor_.set_nvfp4_4over6(nvfp4_4over6);
+}
+
+void Tensor::set_nvfp4_e4m3_max(int nvfp4_e4m3_max) {
+  NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
+             "NVFP4 E4M3 max is only supported for NVFP4 tensors.");
+  tensor_.set_nvfp4_e4m3_max(nvfp4_e4m3_max);
+}
+
+bool Tensor::nvfp4_4over6() const {
+  NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
+             "NVFP4 4over6 is only supported for NVFP4 tensors.");
+  return tensor_.get_nvfp4_4over6();
+}
+
+int Tensor::nvfp4_e4m3_max() const {
+  NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
+             "NVFP4 E4M3 max is only supported for NVFP4 tensors.");
+  return tensor_.get_nvfp4_e4m3_max();
+}
+
 void Tensor::to_cpu() {
   if (data_rowwise_) { data_rowwise_->to_cpu(); }
   if (data_columnwise_) { data_columnwise_->to_cpu(); }
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index 17f36a99dd..851593cae7 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -293,10 +293,15 @@ class Tensor {
     return columnwise_;
   }
 
+  bool nvfp4_4over6() const;
+  int nvfp4_e4m3_max() const;
+
   void set_tensor_amax_nullptr();
   void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales);
   void set_row_scaled_nvfp4(bool row_scaled_nvfp4);
+  void set_nvfp4_4over6(bool nvfp4_4over6);
+  void set_nvfp4_e4m3_max(int nvfp4_e4m3_max);
 
   void to_cpu();
   void from_cpu();
diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
index a7ea4f089f..bd4d029729 100644
--- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
+++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
@@ -28,7 +28,12 @@ def check_nvfp4_gemm_versus_reference(
     x_columnwise: bool = False,
     w_columnwise: bool = False,
     row_scaled_nvfp4: bool = False,
+    use_4over6: bool = False,
+    nvfp4_e4m3_max: int = 448,
+    nvfp4_4over6_err_mode: str = "MAE",
 ):
+    if nvfp4_e4m3_max != 448 and not use_4over6:
+        pytest.skip("E4M3 max 256 is only meaningful for 4over6")
     te_dtype = tex.DType.kFloat4E2M1
 
     # Setup device and random seed
@@ -59,6 +64,9 @@ def check_nvfp4_gemm_versus_reference(
         with_rht=False,
         with_post_rht_amax=False,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     w_quantizer = NVFP4Quantizer(
         fp4_dtype=te_dtype,
@@ -68,6 +76,9 @@ def check_nvfp4_gemm_versus_reference(
         amax_reduction_group=None,
         with_rht=False,
         with_post_rht_amax=False,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
     # Quantize x and w
@@ -123,6 +134,9 @@ def check_nvfp4_gemm_versus_reference(
         eps=0.0,
         quant_tile_shape=(1, 16),
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     w_ref_quantizer = NVFP4QuantizerRef(
         dtype=utils.Fp4Formats.E2M1,
@@ -131,6 +145,9 @@ def check_nvfp4_gemm_versus_reference(
         pow_2_scales=False,
         eps=0.0,
         quant_tile_shape=(1, 16),
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
     # Create reference quantized tensors needed by reference GEMM
@@ -232,6 +249,8 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
     *,
     use_bias: bool,
     single_output: bool,
+    use_4over6: bool = False,
+    nvfp4_4over6_err_mode: str = "MAE",
 ):
     te_dtype = tex.DType.kFloat4E2M1
     device = "cuda"
@@ -249,6 +268,8 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
         with_rht=False,
         with_post_rht_amax=False,
         row_scaled_nvfp4=True,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     w_quantizer = NVFP4Quantizer(
         fp4_dtype=te_dtype,
@@ -258,6 +279,8 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
         amax_reduction_group=None,
         with_rht=False,
         with_post_rht_amax=False,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
     x_nvfp4 = []
@@ -321,6 +344,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
     M: int,
     K: int,
     N: int,
+    use_4over6: bool = False,
+    nvfp4_4over6_err_mode: str = "MAE",
 ):
     te_dtype = tex.DType.kFloat4E2M1
     device = "cuda"
@@ -339,6 +364,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
         with_rht=False,
         with_post_rht_amax=False,
         row_scaled_nvfp4=True,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     x_tensorwise_quantizer = NVFP4Quantizer(
         fp4_dtype=te_dtype,
@@ -348,6 +375,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
         amax_reduction_group=None,
         with_rht=False,
         with_post_rht_amax=False,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     w_quantizer = NVFP4Quantizer(
         fp4_dtype=te_dtype,
@@ -357,6 +386,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
         amax_reduction_group=None,
         with_rht=False,
         with_post_rht_amax=False,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
     x_row_scaled = x_row_scaled_quantizer.update_quantized(
@@ -417,6 +448,9 @@ def check_nvfp4_row_scaled_gemm_matches_emulated(
     ids=["rowxrow", "colxrow", "colxcol"],
 )
 @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
+@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
+@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"])
+@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
 def test_nvfp4_gemm_versus_reference(
     M: int,
     K: int,
@@ -428,6 +462,9 @@ def test_nvfp4_gemm_versus_reference(
     is_x_columnwise: bool,
     is_w_columnwise: bool,
     row_scaled_nvfp4: bool,
+    use_4over6: bool,
+    nvfp4_e4m3_max: int,
+    nvfp4_4over6_err_mode: str,
 ):
     if row_scaled_nvfp4:
         if accumulate:
@@ -446,6 +483,9 @@ def test_nvfp4_gemm_versus_reference(
         x_columnwise=is_x_columnwise,
         w_columnwise=is_w_columnwise,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
@@ -471,6 +511,8 @@ def test_nvfp4_gemm_versus_reference(
 @pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str)
 @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"])
 @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"])
+@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
+@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
 def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
     m_splits: list[int],
     k: int,
@@ -480,6 +522,8 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
     out_dtype: torch.dtype,
     use_bias: bool,
     single_output: bool,
+    use_4over6: bool,
+    nvfp4_4over6_err_mode: str,
 ):
     check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
         x_dtype=x_dtype,
@@ -490,6 +534,8 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
         n=n,
         use_bias=use_bias,
         single_output=single_output,
+        use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
@@ -513,6 +559,8 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
 @pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str)
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
+@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
+@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
 def test_nvfp4_row_scaled_gemm_matches_emulated(
     M: int,
     K: int,
@@ -520,6 +568,8 @@ def test_nvfp4_row_scaled_gemm_matches_emulated(
     x_dtype: torch.dtype,
     w_dtype: torch.dtype,
     out_dtype: torch.dtype,
+    use_4over6: bool,
+    nvfp4_4over6_err_mode: str,
 ):
     check_nvfp4_row_scaled_gemm_matches_emulated(
         x_dtype=x_dtype,
@@ -528,4 +578,6 @@ def test_nvfp4_row_scaled_gemm_matches_emulated(
         M=M,
         K=K,
         N=N,
+        use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
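Outside the parametrized harness, the quantizer configuration exercised by these tests boils down to the following. A sketch using the same keyword arguments the tests pass (shapes and values are illustrative; other constructor arguments are left at their defaults):

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex

# Build a 4over6-enabled NVFP4 quantizer the way the tests above do.
quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer(
    fp4_dtype=tex.DType.kFloat4E2M1,
    rowwise=True,
    columnwise=True,
    with_rht=False,               # 4over6 currently requires RHT disabled
    with_post_rht_amax=False,
    nvfp4_use_4over6=True,        # enable map-to-4 vs map-to-6 selection
    nvfp4_e4m3_max=256,           # tighter global E4M3 bound; 448 is the default
    nvfp4_4over6_err_mode="MSE",  # candidate error metric; "MAE" is the default
)
x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
x_nvfp4 = quantizer(x)            # quantized NVFP4 tensor
```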
diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
index 53569d90d9..5bb92f70dc 100644
--- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
+++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
@@ -20,7 +20,14 @@ def maybe_skip_row_scaled_unsupported_quantization(
     row_scaled_nvfp4: bool,
     return_transpose: bool,
     with_2d_quantization: bool = False,
+    use_4over6: bool = False,
+    x_dtype: torch.dtype | None = None,
+    M: int | None = None,
+    N: int | None = None,
 ) -> None:
+    if use_4over6 and with_2d_quantization:
+        if x_dtype != torch.bfloat16 or M is None or N is None or M % 32 != 0 or N % 32 != 0:
+            pytest.skip("NVFP4 2D 4over6 exact tests require the optimized BF16 kernel path")
     if not row_scaled_nvfp4:
         return
     if return_transpose:
@@ -45,9 +52,14 @@ def check_quantization_nvfp4_versus_reference(
     use_cpp_allocator: bool,
     with_2d_quantization: bool,
     row_scaled_nvfp4: bool = False,
+    use_4over6: bool = False,
+    nvfp4_e4m3_max: int = 448,
+    nvfp4_4over6_err_mode: str = "MAE",
 ) -> None:
+    if nvfp4_e4m3_max != 448 and not use_4over6:
+        pytest.skip("E4M3 max 256 is only meaningful for 4over6")
     maybe_skip_row_scaled_unsupported_quantization(
-        row_scaled_nvfp4, return_transpose, with_2d_quantization
+        row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6, x_dtype, M, N
     )
 
     te_dtype = tex.DType.kFloat4E2M1
@@ -71,6 +83,9 @@ def check_quantization_nvfp4_versus_reference(
         with_post_rht_amax=False,
         with_2d_quantization=with_2d_quantization,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     if use_cpp_allocator:
         x_nvfp4_sut = nvfp4_quantizer(x)
@@ -104,6 +119,9 @@ def check_quantization_nvfp4_versus_reference(
         eps=0.0,
         quant_tile_shape=quant_tile_shape,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     x_nvfp4_ref = ref_quantizer.quantize(x)
 
@@ -179,6 +197,9 @@ def check_quantization_nvfp4_versus_reference(
     "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
 )
 @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
+@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
+@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"])
+@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
 def test_quantization_block_tiling_versus_reference(
     x_dtype: torch.dtype,
     M: int,
@@ -188,6 +209,9 @@ def test_quantization_block_tiling_versus_reference(
     use_cpp_allocator: bool,
     with_2d_quantization: bool,
     row_scaled_nvfp4: bool,
+    use_4over6: bool,
+    nvfp4_e4m3_max: int,
+    nvfp4_4over6_err_mode: str,
 ) -> None:
     check_quantization_nvfp4_versus_reference(
         x_dtype=x_dtype,
@@ -198,6 +222,9 @@ def test_quantization_block_tiling_versus_reference(
         use_cpp_allocator=use_cpp_allocator,
         with_2d_quantization=with_2d_quantization,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        use_4over6=use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
 
@@ -215,6 +242,8 @@ def test_quantization_block_tiling_versus_reference(
     "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
 )
 @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
+@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
+@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
 def test_nvfp4_quantization_extrema_versus_reference(
     x_dtype: torch.dtype,
     M: int,
@@ -223,8 +252,12 @@ def test_nvfp4_quantization_extrema_versus_reference(
     return_transpose: bool,
     use_cpp_allocator: bool,
     row_scaled_nvfp4: bool,
+    use_4over6: bool,
+    nvfp4_4over6_err_mode: str,
 ):
-    maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose)
+    maybe_skip_row_scaled_unsupported_quantization(
+        row_scaled_nvfp4, return_transpose, use_4over6=use_4over6
+    )
 
     te_dtype = tex.DType.kFloat4E2M1
 
@@ -247,6 +280,8 @@ def test_nvfp4_quantization_extrema_versus_reference(
         with_rht=False,
         with_post_rht_amax=False,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
     if use_cpp_allocator:
@@ -278,6 +313,8 @@ def test_nvfp4_quantization_extrema_versus_reference(
         eps=0.0,
         quant_tile_shape=(1, 16),
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     x_nvfp4_ref = ref_quantizer.quantize(x)
 
@@ -322,6 +359,8 @@ def test_nvfp4_quantization_extrema_versus_reference(
     "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
 )
 @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
+@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
+@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
 def test_nvfp4_quantization_boundary_values(
     x_dtype: torch.dtype,
     M: int,
@@ -329,13 +368,17 @@ def test_nvfp4_quantization_boundary_values(
     return_transpose: bool,
     use_cpp_allocator: bool,
     row_scaled_nvfp4: bool,
+    use_4over6: bool,
+    nvfp4_4over6_err_mode: str,
 ):
     """
     Stress rounding/threshold behavior by placing values just below/above many
     potential bin edges within each 16-element microblock.
     Validates native vs reference byte-for-byte and scale parity.
     """
-    maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose)
+    maybe_skip_row_scaled_unsupported_quantization(
+        row_scaled_nvfp4, return_transpose, use_4over6=use_4over6
+    )
 
     te_dtype = tex.DType.kFloat4E2M1
 
@@ -367,6 +410,8 @@ def test_nvfp4_quantization_boundary_values(
         with_rht=False,
         with_post_rht_amax=False,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
     if use_cpp_allocator:
@@ -398,6 +443,8 @@ def test_nvfp4_quantization_boundary_values(
         eps=0.0,
         quant_tile_shape=(1, 16),
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     x_nvfp4_ref = ref_quantizer.quantize(x)
 
@@ -442,6 +489,8 @@ def test_nvfp4_quantization_boundary_values(
     "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
 )
 @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"])
+@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"])
+@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"])
 def test_nvfp4_quantization_noncontiguous_inputs(
     x_dtype: torch.dtype,
     M: int,
@@ -449,8 +498,12 @@ def test_nvfp4_quantization_noncontiguous_inputs(
     return_transpose: bool,
     use_cpp_allocator: bool,
     row_scaled_nvfp4: bool,
+    use_4over6: bool,
+    nvfp4_4over6_err_mode: str,
 ):
-    maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose)
+    maybe_skip_row_scaled_unsupported_quantization(
+        row_scaled_nvfp4, return_transpose, use_4over6=use_4over6
+    )
 
     te_dtype = tex.DType.kFloat4E2M1
 
@@ -473,6 +526,8 @@ def test_nvfp4_quantization_noncontiguous_inputs(
         with_rht=False,
         with_post_rht_amax=False,
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
 
     if use_cpp_allocator:
@@ -504,6 +559,8 @@ def test_nvfp4_quantization_noncontiguous_inputs(
         eps=0.0,
         quant_tile_shape=(1, 16),
         row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=use_4over6,
+        nvfp4_4over6_err_mode=nvfp4_4over6_err_mode,
     )
     x_nvfp4_ref = ref_quantizer.quantize(x_nc)
diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py
index 43e9587d95..5e6f36e8b4 100644
--- a/tests/pytorch/test_backward_override.py
+++ b/tests/pytorch/test_backward_override.py
@@ -83,6 +83,11 @@
         marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
         id="NVFP4RowScaledBlockScaling",
     ),
+    pytest.param(
+        "nvfp4_4over6",
+        marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
+        id="NVFP44Over6BlockScaling",
+    ),
 ]
 
 
@@ -170,7 +175,7 @@ def _maybe_skip_recipe_dtype(
 ) -> None:
     if dtype == torch.bfloat16 and not bf16_available:
         pytest.skip(reason_for_no_bf16)
-    if recipe_name in ("nvfp4", "nvfp4_row_scaled"):
+    if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"):
         if module_type in ("linear", "layernorm_linear") and dtype not in (
             torch.bfloat16,
             torch.float32,
@@ -185,6 +190,8 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s
         pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe")
     if module_type == "ops_linear" and recipe_name == "nvfp4_row_scaled":
         pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.")
+    if module_type == "grouped_linear" and recipe_name == "nvfp4_4over6":
+        pytest.skip("NVFP4 4over6 currently does not support grouped quantization.")
 
 
 def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe:
@@ -208,7 +215,7 @@ def _maybe_skip_unsupported_recipe_shape(
                 " by 32."
             )
         return
-    if recipe_name in ("nvfp4", "nvfp4_row_scaled") and (
+    if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and (
         flat_first_dim % 16 != 0 or last_dim % 16 != 0
     ):
         pytest.skip(
@@ -235,7 +242,7 @@ def _maybe_skip_unsupported_recipe_shape(
         pytest.skip(
             "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32."
         )
-    if recipe_name in ("nvfp4", "nvfp4_row_scaled") and (
+    if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and (
         flat_first_dim % 16 != 0 or last_dim % 16 != 0
     ):
         pytest.skip(
@@ -256,9 +263,13 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]
         )
     if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits):
         pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.")
-    if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 16 != 0 for m in non_empty_splits):
+    if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and any(
+        m % 16 != 0 for m in non_empty_splits
+    ):
         pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.")
-    if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 64 != 0 for m in non_empty_splits):
+    if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and any(
+        m % 64 != 0 for m in non_empty_splits
+    ):
         pytest.skip(
             "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty "
             "m_split divisible by 64 due to grouped amax kernel constraints."
         )
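With the new `NVFP44Over6BlockScaling` id registered above, the 4over6 recipe cases can be selected on their own. One illustrative way to drive that from Python (equivalent to passing `-k` on the pytest command line):

```python
import pytest

# Run only the "nvfp4_4over6" recipe parametrization added above;
# "NVFP44Over6BlockScaling" is the pytest id registered in the param list.
raise SystemExit(pytest.main([
    "tests/pytorch/test_backward_override.py",
    "-k", "NVFP44Over6BlockScaling",
    "-v",
]))
```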
diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 50196782f2..35cc98a976 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -19,7 +19,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from utils import ModelConfig, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, skip_unsupported_backward_override import transformer_engine_torch as tex # Check supported quantization schemes @@ -28,6 +28,33 @@ mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available() + +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + quantization_recipes: List[Optional[recipe.Recipe]] = [None] if fp8_available: quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) @@ -37,6 +64,8 @@ quantization_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: quantization_recipes.append(recipe.NVFP4BlockScaling()) + quantization_recipes.append(nvfp4_4over6()) + quantization_recipes.append(nvfp4_row_scaled()) model_config = { @@ -176,7 +205,20 @@ def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) return quantizer(tensor) elif recipe.nvfp4(): - quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer() + qparams = recipe.fp4_quant_fwd_inp + use_4over6 = False + if recipe.nvfp4_4over6 in ("activations", "all"): + use_4over6 = True + quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer( + rowwise=True, + columnwise=not recipe.row_scaled_activation, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + row_scaled_nvfp4=recipe.row_scaled_activation, + nvfp4_use_4over6=use_4over6, + ) return quantizer(tensor) @staticmethod @@ -191,10 +233,24 @@ def get_tensor_size_mb(tensor): if tensor is None: return 0 if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage): - return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors()) + tensors = [ + value for value in tensor.get_metadata().values() if isinstance(value, torch.Tensor) + ] + return sum(Utils.get_tensor_size_mb(t) for t in tensors) else: return tensor.numel() * tensor.element_size() / (1024**2) + @staticmethod + def get_saved_tensor_gpu_size_mb(tensor): + if tensor is None or isinstance(tensor, int): + return 0 + if isinstance(tensor, tuple): + push_results, _ = tensor + return Utils.get_saved_tensor_gpu_size_mb(push_results) + if isinstance(tensor, list): + return 
sum(Utils.get_saved_tensor_gpu_size_mb(t) for t in tensor) + return Utils.get_tensor_size_mb(tensor) + @staticmethod def memory_leak_check(): # Should be called before each test. @@ -212,7 +268,7 @@ def memory_leak_check(): class TestsOffloadableLayerState: @pytest.mark.parametrize("random_num_tensors", [True, False]) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_general(self, random_num_tensors, recipe): """ Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, @@ -289,7 +345,7 @@ def test_offload_base_tensor(self): class TestsDefaultOffloadSynchronizer: @pytest.mark.parametrize("random_num_tensors", [True, False]) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_general(self, random_num_tensors, recipe): """ Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, @@ -335,7 +391,7 @@ def test_general(self, random_num_tensors, recipe): offload_synchronizer.finish_part_of_bwd() torch.cuda.synchronize() - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_memory(self, recipe): torch.cuda.synchronize() Utils.memory_leak_check() @@ -363,11 +419,16 @@ def test_memory(self, recipe): del tensor, tensor_id torch.cuda.synchronize() + resident_gpu_size = sum( + Utils.get_saved_tensor_gpu_size_mb(tensor_id) for tensor_id in tensor_ids + ) if recipe is None: assert Utils.get_max_cuda_memory_mb() == pytest.approx( - init_cuda_memory + tensor_size, 0.1 + init_cuda_memory + resident_gpu_size, 0.1 ) - assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1) + assert Utils.get_cuda_memory_mb() == pytest.approx( + init_cuda_memory + resident_gpu_size, 0.1 + ) for i in range(NUM_LAYERS - 1, -1, -1): offload_synchronizer.bwd_step(i) @@ -385,7 +446,7 @@ def test_memory(self, recipe): ) assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_multiple_tensor_offload(self, recipe): Utils.memory_leak_check() init_cpu_memory = Utils.get_cpu_memory_mb() @@ -416,7 +477,7 @@ def test_multiple_tensor_offload(self, recipe): class TestTELayers: @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_sanity(self, layer_type, recipe, backward_override): Utils.memory_leak_check() @@ -463,7 +524,7 @@ def test_sanity(self, layer_type, recipe, backward_override): del out, inp, layers @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_memory(self, layer_type, recipe, backward_override): Utils.memory_leak_check() @@ -536,7 +597,9 @@ def test_memory(self, layer_type, recipe, backward_override): out = out + 1 out = sync_function(out) del inp - if backward_override is None: + if recipe is not None and 
recipe.nvfp4() and recipe.row_scaled_activation: + assert Utils.get_cuda_memory_mb() <= cuda_memory_no_offload + elif backward_override is None: assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) else: assert ( @@ -554,7 +617,7 @@ def test_memory(self, layer_type, recipe, backward_override): out.sum().backward() @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_manual_synchronization(self, recipe, layer_type, backward_override): Utils.memory_leak_check() @@ -623,7 +686,7 @@ def test_manual_synchronization(self, recipe, layer_type, backward_override): out_1.sum().backward() out_2.sum().backward() - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("use_cuda_graphs", [True, False]) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 33ba65e0d9..bb4a4e3857 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -65,13 +65,31 @@ def nvfp4_rht_and_2d_quantization(): def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -101,6 +119,7 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) fp8_recipes.append(nvfp4_row_scaled()) + fp8_recipes.append(nvfp4_4over6()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7691582f97..82d903a39b 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -39,6 +39,7 @@ Float8Quantizer, MXFP8Quantizer, NVFP4Quantizer, + QuantizerRole, is_bf16_available, ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor @@ -77,6 +78,7 @@ _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") + _quantization_list.append("nvfp4_4over6") @pytest.fixture(autouse=True, scope="class") @@ -106,7 +108,7 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if 
quantization == "nvfp4" and not nvfp4_available: + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and not nvfp4_available: pytest.skip(reason_for_no_nvfp4) # Check dims @@ -119,13 +121,16 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: - if quantization == "nvfp4" and dtype != torch.bfloat16: + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and dtype != torch.bfloat16 + ): pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -141,6 +146,7 @@ def make_reference_and_test_tensors( test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", test_is_quantized: bool = False, + quantizer_role: Optional[QuantizerRole] = None, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -151,6 +157,8 @@ def make_reference_and_test_tensors( If a quantization scheme is provided, the tensor values are quantized so that they are representable. + NVFP4 4over6 follows recipe role dispatch: activation-like tensors + use 1D quantization and weight tensors use the 2D weight path. """ @@ -180,13 +188,17 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + with_2d_quantization = False + if quantization == "nvfp4_4over6" and quantizer_role is not None: + with_2d_quantization = quantizer_role.tensor_type == "weight" test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, - with_2d_quantization=False, + with_2d_quantization=with_2d_quantization, stochastic_rounding=False, with_random_sign_mask=False, + nvfp4_use_4over6=quantization == "nvfp4_4over6", )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -503,6 +515,7 @@ def test_dtype_cast( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) # Construct operation @@ -910,6 +923,7 @@ def _test_basic_linear( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1082,6 +1096,7 @@ def test_linear( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -1512,7 +1527,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1883,7 +1898,7 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute and quantization == "nvfp4": + if quantized_compute and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): 
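+        # FP4 E2M1 precision dominates the rounding error for every NVFP4
+        # variant, so 4over6 and row-scaled reuse the FP4 tolerances.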
tols = dtype_tols(tex.DType.kFloat4E2M1) elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) @@ -2076,6 +2091,8 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") @@ -2112,6 +2129,7 @@ def test_grouped_linear( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), requires_grad=weight_requires_grad, ) b_ref, b_test = None, None @@ -2622,6 +2640,7 @@ def test_forward_linear_bias_activation( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -2727,6 +2746,7 @@ def test_forward_linear_bias_add( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -2840,6 +2860,7 @@ def test_forward_linear_scale_add( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) x2_ref, x2_test = make_reference_and_test_tensors( out_shape, @@ -3122,6 +3143,7 @@ def test_backward_linear_add( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -3225,6 +3247,7 @@ def test_backward_linear_scale( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -3454,12 +3477,14 @@ def test_layernorm_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) w2_ref, w2_test = make_reference_and_test_tensors( (hidden_size, ffn_hidden_size // 2), quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b1_ref, b1_test, b2_ref, b2_test = None, None, None, None if bias: @@ -3608,7 +3633,14 @@ def test_grouped_mlp( pytest.skip("single_grouped_bias requires bias=True") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") - if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if ( + with_quantization + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and activation == "scaled_clamped_qgeglu" + and bias + ): # TODO: ksivaman: Need to debug numerics for this case. 
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") @@ -3647,6 +3679,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( (hidden_size, hidden_size), @@ -3655,6 +3688,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) fc1_b_ref, fc1_b_test = None, None fc2_b_ref, fc2_b_test = None, None @@ -3830,7 +3864,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization == "nvfp4": + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = {"rtol": 0.25, "atol": 0.5} # Check values diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a718ea2a8a..5f82bfcba2 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -54,7 +54,7 @@ from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex -from utils import ModelConfig, reset_rng_states +from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override # Only run FP8 tests on supported devices. @@ -138,6 +138,32 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="high_precision", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -171,6 +197,8 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes.append(recipe.DelayedScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + fp8_recipes.append(nvfp4_4over6()) + fp8_recipes.append(nvfp4_row_scaled()) use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper @@ -627,11 +655,15 @@ def _test_e2e_selective_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 or fp8_model_params: + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if fp8 and recipe.nvfp4(): if dtype not in 
get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( @@ -739,7 +771,7 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) def test_gpt_full_activation_recompute( @@ -747,6 +779,10 @@ def test_gpt_full_activation_recompute( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 or fp8_model_params: + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if fp8 and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( @@ -1324,7 +1360,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) def test_linear_accuracy_save_original_input(dtype, model, recipe): bs = 1 fuse_wgrad_accumulation = True @@ -1333,6 +1369,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") + skip_unsupported_backward_override("linear", recipe, getattr(recipe, "backward_override", None)) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -1894,7 +1931,7 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -1917,6 +1954,9 @@ def test_grouped_linear_accuracy( pytest.skip("FP8 parameters are not supported in debug mode.") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2037,7 +2077,7 @@ def test_grouped_linear_accuracy_cutlass( @pytest.mark.parametrize("num_gemms", [3]) @pytest.mark.parametrize("bs", [1]) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", [False]) @pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) @pytest.mark.parametrize("bias", [False]) @@ -2061,6 +2101,9 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = 
model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2139,7 +2182,7 @@ def test_grouped_linear_accuracy_save_original_input( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( @@ -2253,7 +2296,7 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( dtype, @@ -2267,6 +2310,9 @@ def test_padding_grouped_linear_accuracy( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2328,7 +2374,7 @@ def test_padding_grouped_linear_accuracy( @pytest.mark.parametrize("bs", [1]) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", [False]) def test_padding_grouped_linear_accuracy_save_original_input( dtype, @@ -2344,6 +2390,9 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2559,10 +2608,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 526045e43e..1ddadd8c71 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -28,7 +28,7 @@ import transformer_engine_torch as tex from references.ref_per_tensor_cs import ref_per_tensor_cs_cast -from utils import assert_close, quantization_tols +from utils import assert_close # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] @@ -69,6 +69,8 @@ def _to_list(x: Union[Iterable, Any]) -> List: _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") + _quantization_list.append("nvfp4_row_scaled") + _quantization_list.append("nvfp4_4over6") # 
delayed scaling @@ -163,13 +165,17 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + row_scaled_nvfp4 = quantization == "nvfp4_row_scaled" test = NVFP4Quantizer( + columnwise=not row_scaled_nvfp4, with_rht=False, with_post_rht_amax=False, with_2d_quantization=False, stochastic_rounding=False, + row_scaled_nvfp4=row_scaled_nvfp4, with_random_sign_mask=False, + nvfp4_use_4over6=(quantization == "nvfp4_4over6"), )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -735,13 +741,16 @@ def test_update_nd_tensor( ) elif quantization == "mxfp8": quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) - elif quantization in ("nvfp4", "nvfp4_2d"): + elif quantization in ("nvfp4", "nvfp4_2d", "nvfp4_row_scaled", "nvfp4_4over6"): + row_scaled_nvfp4 = quantization == "nvfp4_row_scaled" quantizer = NVFP4Quantizer( rowwise=True, - columnwise=True, + columnwise=not row_scaled_nvfp4, with_rht=False, with_post_rht_amax=False, with_2d_quantization=(quantization == "nvfp4_2d"), + row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=(quantization == "nvfp4_4over6"), ) quantization = "nvfp4" else: @@ -756,9 +765,9 @@ def test_update_nd_tensor( q_x.copy_(x_new) # Check results + q_ref = quantizer(x_new) assert q_x.shape == torch.Size(shape) - tols = quantization_tols(quantization) - assert_close(q_x, x_new, **tols) + assert_close(q_x, q_ref, rtol=0, atol=0) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 5f5221af76..9ea5e738b4 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, NVFP4BlockScalingRecipeState, + QuantizerRole, _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops @@ -514,8 +515,49 @@ def test_quantizer_update(self, module_class): @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) -def test_nvfp4_row_scaled_quantizer_roles(): - recipe = NVFP4BlockScaling(row_scaled_activation=True) +@pytest.mark.parametrize( + "nvfp4_4over6", + [None, "weights", "activations", "all"], + ids=["default", "weights", "activations", "all"], +) +@pytest.mark.parametrize( + "nvfp4_e4m3_max", + [None, "weights", "activations", "all"], + ids=["e4m3_448", "e4m3_256_weights", "e4m3_256_activations", "e4m3_256_all"], +) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +def test_nvfp4_row_scaled_quantizer_roles(nvfp4_4over6, nvfp4_e4m3_max, nvfp4_4over6_err_mode): + recipe = NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + nvfp4_4over6=nvfp4_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + row_scaled_activation=True, + ) + + def expected_use_4over6(tensor_type): + if nvfp4_4over6 == "all": + return True + if nvfp4_4over6 == "weights": + return tensor_type == "weight" + if nvfp4_4over6 == "activations": + return tensor_type != "weight" + return False + + def expected_e4m3_max(tensor_type): + if not expected_use_4over6(tensor_type): + return 448 + if nvfp4_e4m3_max == "all": + return 256 + if nvfp4_e4m3_max == "weights": + if tensor_type == "weight": + return 256 + if nvfp4_e4m3_max == 
"activations": + if tensor_type != "weight": + return 256 + return 448 forward_quantizers = NVFP4BlockScalingRecipeState( recipe, @@ -523,20 +565,59 @@ def test_nvfp4_row_scaled_quantizer_roles(): num_quantizers=3, ).make_quantizers() assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] + assert [q.nvfp4_use_4over6 for q in forward_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("input", "weight", "output") + ] + assert [q.nvfp4_e4m3_max for q in forward_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("input", "weight", "output") + ] + assert [q.nvfp4_4over6_err_mode for q in forward_quantizers] == [nvfp4_4over6_err_mode] * 3 assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) + role_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="forward", + num_quantizers=4, + roles=[ + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="output"), + None, + ], + ).make_quantizers() + assert [q.row_scaled_nvfp4 for q in role_quantizers] == [False, True, True, True] + assert [q.nvfp4_use_4over6 for q in role_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("weight", "input", "output", "input") + ] + assert [q.nvfp4_e4m3_max for q in role_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("weight", "input", "output", "input") + ] + assert [q.nvfp4_4over6_err_mode for q in role_quantizers] == [nvfp4_4over6_err_mode] * 4 + backward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="backward", num_quantizers=2, + roles=[ + QuantizerRole(module_type="linear", tensor_type="grad_output"), + QuantizerRole(module_type="linear", tensor_type="grad_input"), + ], ).make_quantizers() assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] + assert [q.nvfp4_use_4over6 for q in backward_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("grad_output", "grad_input") + ] + assert [q.nvfp4_e4m3_max for q in backward_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("grad_output", "grad_input") + ] + assert [q.nvfp4_4over6_err_mode for q in backward_quantizers] == [nvfp4_4over6_err_mode] * 2 @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( "M, N", [ @@ -552,24 +633,30 @@ def test_nvfp4_row_scaled_quantizer_roles(): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, row_scaled_nvfp4, M, N): +def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N): q = NVFP4Quantizer( columnwise=not row_scaled_nvfp4, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) assert starting_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert starting_tensor._nvfp4_use_4over6 == use_4over6 assert starting_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert new_tensor._nvfp4_use_4over6 == use_4over6 assert 
new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) - torch.testing.assert_close( - new_tensor._rowwise_data, - starting_tensor._rowwise_data, - rtol=0, - atol=0, - ) + # 4over6 can re-encode a dequantized block with the alternate 4/6 scale + # choice while preserving the dequantized values. + if not use_4over6: + torch.testing.assert_close( + new_tensor._rowwise_data, + starting_tensor._rowwise_data, + rtol=0, + atol=0, + ) new_dequantized_tensor = new_tensor.dequantize() torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c811342df5..27eafbecdc 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -95,27 +95,43 @@ def nvfp4_vanilla(): def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this + fp8_recipes.append(nvfp4_4over6()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) -fp8_recipes_with_row_scaled = fp8_recipes.copy() -if nvfp4_available: - fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -415,7 +431,11 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -463,7 +483,11 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -501,7 +525,11 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) 
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -542,7 +570,11 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -621,7 +653,7 @@ def test_sanity_grouped_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @@ -671,7 +703,7 @@ def test_sanity_layernorm_mlp( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -744,7 +776,7 @@ def test_sanity_gpt_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) @@ -800,7 +832,7 @@ def test_sanity_bert_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) @@ -856,7 +888,7 @@ def test_sanity_T5_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @@ -889,7 +921,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) def test_sanity_drop_path(dtype, fp8_recipe, model): config = model_configs[model] @@ -924,7 +956,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) 
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): @@ -960,7 +992,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad): diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 51f72b1e56..137e5f5a77 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -39,6 +39,33 @@ fp8_block_scaling_available = is_fp8_block_scaling_available() nvfp4_available = is_nvfp4_available() + +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + _all_recipes: list = [] if fp8_available: _all_recipes.append(recipe.Float8CurrentScaling()) @@ -48,7 +75,8 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) - _all_recipes.append(recipe.NVFP4BlockScaling(row_scaled_activation=True)) + _all_recipes.append(nvfp4_4over6()) + _all_recipes.append(nvfp4_row_scaled()) # --------------------------------------------------------------------------- @@ -97,8 +125,19 @@ def __fx_repr__(self): def _make_qfactory(tag: str): """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" + quantizers = { + role: ToyQuantizer(tag=f"{tag}:{role}") + for role in ( + "linear_input", + "linear_weight", + "linear_output", + "linear_grad_output", + "linear_grad_input", + ) + } + def qfactory(role: str): - return ToyQuantizer(tag=f"{tag}:{role}") + return quantizers[role] return qfactory diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 2ee18aaf57..f8113332d9 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_row_scaled"): + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -145,21 +145,17 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) - if name == "nvfp4": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - **recipe_kwargs, - ) - if name == 
"nvfp4_row_scaled": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - row_scaled_activation=True, - **recipe_kwargs, - ) + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + use_4over6 = name == "nvfp4_4over6" + kwargs = { + "disable_rht": True, + "disable_stochastic_rounding": True, + "disable_2d_quantization": not use_4over6, + "row_scaled_activation": name == "nvfp4_row_scaled", + "nvfp4_4over6": "all" if use_4over6 else None, + } + kwargs.update(recipe_kwargs) + return transformer_engine.common.recipe.NVFP4BlockScaling(**kwargs) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -167,6 +163,10 @@ def recipe_id(recipe: Optional[Recipe]) -> str: """Readable pytest id for a quantization recipe.""" if not isinstance(recipe, Recipe): return "None" + if recipe.nvfp4() and recipe.row_scaled_activation and recipe.nvfp4_4over6 is not None: + return "NVFP4RowScaled4Over6BlockScaling" + if recipe.nvfp4() and recipe.nvfp4_4over6 is not None: + return "NVFP44Over6BlockScaling" if recipe.nvfp4() and recipe.row_scaled_activation: return "NVFP4RowScaledBlockScaling" return type(recipe).__name__ @@ -185,6 +185,13 @@ def skip_unsupported_backward_override( and backward_override is None ): pytest.skip("Row-scaled NVFP4 does not support default quantized backward.") + if ( + quant_recipe is not None + and quant_recipe.nvfp4() + and quant_recipe.nvfp4_4over6 is not None + and layer_type == "grouped_linear" + ): + pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 123362ce10..65b267433e 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,6 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" +#include "../nvfp4/quantize_4over6_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -101,6 +102,16 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; + NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6, + "Tensor and quantization config have inconsistent options for NVFP4 4over6."); + NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); + const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6; + NVTE_CHECK(nvfp4_use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); + NVTE_CHECK(!nvfp4_use_4over6 || !quant_config_cpp.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); if (row_scaled_nvfp4) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -112,7 +123,15 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { + if (nvfp4_use_4over6) { 
+    if (quant_config_cpp.nvfp4_2d_quantization) {
+      nvfp4::quantize_4over6<true>(
+          *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+    } else {
+      nvfp4::quantize_4over6<false>(
+          *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+    }
+  } else if (use_optimized_kernel) {
     if (quant_config_cpp.nvfp4_2d_quantization) {
       nvfp4::quantize_transpose<true>(
           *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
@@ -249,13 +268,31 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
   int32_t rows = grad_tensor->flat_first_dim();
   int32_t cols = grad_tensor->flat_last_dim();
   auto dtype = grad_tensor->dtype();
+  NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6,
+             "Tensor and quantization config have inconsistent options for NVFP4 4over6.");
+  NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max,
+             "Tensor and quantization config have inconsistent options for NVFP4 4over6 "
+             "E4M3 scale bound.");
+  const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6;
+  NVTE_CHECK(nvfp4_use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448,
+             "Non-4over6 NVFP4 quantization requires E4M3 max 448.");
+  NVTE_CHECK(!nvfp4_use_4over6 || !quant_config_cpp.stochastic_rounding,
+             "NVFP4 4over6 quantization does not support stochastic rounding.");
   NVTE_CHECK(!output_tensor->row_scaled_nvfp4,
              "Backward NVFP4 quantization does not support row-scaled outputs.");
   bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                               (cols % 32 == 0) && output_tensor->has_data();
 
   // Launch NVFP4 quantize kernel
-  if (use_optimized_kernel) {
+  if (nvfp4_use_4over6) {
+    if (quant_config_cpp.nvfp4_2d_quantization) {
+      nvfp4::quantize_4over6<true>(
+          *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+    } else {
+      nvfp4::quantize_4over6<false>(
+          *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+    }
+  } else if (use_optimized_kernel) {
     if (quant_config_cpp.nvfp4_2d_quantization) {
       nvfp4::quantize_transpose<true>(
           *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
@@ -277,7 +314,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
           /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
           /*rng_state=*/quant_config_cpp.rng_state,
           /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
-          /*row_scaled_nvfp4=*/false, /*noop_tensor=*/noop_tensor->data,
+          /*row_scaled_nvfp4=*/false,
+          /*noop_tensor=*/noop_tensor->data,
           /*stream=*/stream);
     } break;
 
@@ -372,8 +410,20 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou
   int32_t cols = input_tensor->flat_last_dim();
   auto dtype = input_tensor->dtype();
 
+  for (const auto *output_tensor : output_tensors) {
+    NVTE_CHECK(quant_config_cpp.nvfp4_4over6 == output_tensor->nvfp4_4over6,
+               "Tensor and quantization config have inconsistent options for NVFP4 4over6.");
+    NVTE_CHECK(quant_config_cpp.nvfp4_e4m3_max == output_tensor->nvfp4_e4m3_max,
+               "Tensor and quantization config have inconsistent options for NVFP4 4over6 "
+               "E4M3 scale bound.");
+  }
+  const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6;
+  NVTE_CHECK(nvfp4_use_4over6 || quant_config_cpp.nvfp4_e4m3_max == 448,
+             "Non-4over6 NVFP4 quantization requires E4M3 max 448.");
   NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
              "2D quantization is not supported for group quantize.");
+  NVTE_CHECK(!nvfp4_use_4over6,
+             "NVFP4 4over6 quantization is not supported for group quantize.");
 
   // Launch NVFP4 group quantize kernel
   nvfp4::group_quantize_transpose(
diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh
index 792b068cbc..3820430d5b 100644
--- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh
+++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh
@@ -75,10 +75,14 @@ namespace core {
 #if FP4_TYPE_SUPPORTED
 using namespace ptx;
 
-// Compute the global encode scale factor for a given global amax
+// Compute the global encode scale factor for a given global amax.
+// NVFP4 uses the full E4M3 range by default. Some 4over6 tensors dispatch
+// E4M3_MAX=256 to leave room for map-to-4 scale expansion.
+template <int E4M3_MAX = 448>
 __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
   using namespace detail;
-  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;  // 448.0f;
+  static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max.");
+  constexpr float fp8_max = static_cast<float>(E4M3_MAX);
   constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f;
   float global_encode_scale = fp8_max * fp4_max / global_amax;
   // If scale is infinity, return max value of float32
diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
index d549a050ee..faf3c58adf 100644
--- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
+++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
@@ -31,12 +31,11 @@ namespace dispatch {
 namespace nvfp4 {
 namespace dequantize_kernel {
 #if FP4_TYPE_SUPPORTED
-template <typename OType, bool WITH_GEMM_SWIZZLED_SCALES>
+template <typename OType, bool WITH_GEMM_SWIZZLED_SCALES, bool ROW_SCALED_NVFP4, int E4M3_MAX>
 __global__ void __launch_bounds__(512)
     dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
-                          const float *const tensor_amax, const bool row_scaled_nvfp4,
-                          const size_t N, const size_t M, const size_t scale_stride,
-                          const size_t num_scale_tiles_X) {
+                          const float *const tensor_amax, const size_t N, const size_t M,
+                          const size_t scale_stride, const size_t num_scale_tiles_X) {
   const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t x = thread_idx % M;
   const size_t y = thread_idx / M;
@@ -64,8 +63,9 @@
     fp4vec value;
     value.vec = input_vectorized[my_index];
     fp8e4m3 scale = scales[my_scale_index];
-    float amax = row_scaled_nvfp4 ? tensor_amax[y] : tensor_amax[0];
-    constexpr float factor_inv = 1.0 / (6.0 * 448.0);
+    float amax = ROW_SCALED_NVFP4 ? tensor_amax[y] : tensor_amax[0];
+    static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max.");
+    constexpr float factor_inv = 1.0f / (6.0f * static_cast<float>(E4M3_MAX));
     float final_scale = static_cast<float>(scale) * amax * factor_inv;
 #pragma unroll
     for (int i = 0; i < 4; i++) {
@@ -92,6 +92,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
   const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales;
   const bool row_scaled_nvfp4 = input.row_scaled_nvfp4;
+  const int e4m3_max = input.nvfp4_e4m3_max;
 
   constexpr int FP4_BLOCK_SIZE = 16;
   const size_t N = input.flat_first_dim();
@@ -112,14 +113,25 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
       output->data.dtype, OType,
       TRANSFORMER_ENGINE_SWITCH_CONDITION(
           with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
-
-          dequantize_fp4_kernel<OType, WITH_GEMM_SWIZZLED_SCALES><<>>(
-              input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
-              reinterpret_cast<const fp8e4m3 *>(input.scale_inv.dptr),
-              reinterpret_cast<const float *>(input.amax.dptr), row_scaled_nvfp4, N, Mread,
-              input.scale_inv.shape.back(),
-              num_scale_tiles_X););  // NOLINT(*)
-  );                                // NOLINT(*)
+          TRANSFORMER_ENGINE_SWITCH_CONDITION(
+              row_scaled_nvfp4, ROW_SCALED_NVFP4,
+              if (e4m3_max == 256) {
+                dequantize_fp4_kernel<OType, WITH_GEMM_SWIZZLED_SCALES, ROW_SCALED_NVFP4, 256>
+                    <<>>(
+                        input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
+                        reinterpret_cast<const fp8e4m3 *>(input.scale_inv.dptr),
+                        reinterpret_cast<const float *>(input.amax.dptr), N, Mread,
+                        input.scale_inv.shape.back(), num_scale_tiles_X);
+              } else {
+                NVTE_CHECK(e4m3_max == 448, "Unsupported NVFP4 E4M3 max (got ", e4m3_max, ")");
+                dequantize_fp4_kernel<OType, WITH_GEMM_SWIZZLED_SCALES, ROW_SCALED_NVFP4, 448>
+                    <<>>(
+                        input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
+                        reinterpret_cast<const fp8e4m3 *>(input.scale_inv.dptr),
+                        reinterpret_cast<const float *>(input.amax.dptr), N, Mread,
+                        input.scale_inv.shape.back(), num_scale_tiles_X);
+              }););  // NOLINT(*)
+  );                 // NOLINT(*)
   NVTE_CHECK_CUDA(cudaGetLastError());
 #else
   NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh
new file mode 100644
index 0000000000..a0b67c8e47
--- /dev/null
+++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh
@@ -0,0 +1,671 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file quantize_4over6_nvfp4.cuh
+ *  \brief Dedicated kernels for NVFP4 4over6 quantization.
+ *
+ * Four Over Six evaluates two TE-style NVFP4 encodings for every 1x16
+ * quantization group. The map-to-6 candidate uses the normal scale. The
+ * map-to-4 candidate expands the E4M3 block scale by 1.5x so FP4 value 4
+ * reaches the same range that FP4 value 6 reaches in the normal encoding.
+ * The selected candidate is the one with lower configured dequantization
+ * error; ties select map-to-6. The quantized candidates, dequantized values,
+ * and errors are kept in registers, matching the structure of the official
+ * Four Over Six implementation.
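+ *
+ * Illustrative worked example (numbers invented here so that every scale is
+ * exactly representable in E4M3): with tensor-level amax 96 and block amax 3,
+ * the global encode scale is 448 * 6 / 96 = 28 and the base block scale is
+ * 3 / 6 * 28 = 14. Map-to-6 keeps scale 14, so block amax 3 encodes to FP4
+ * value 6; map-to-4 expands it to 21, so the same amax encodes to FP4 value
+ * 4. Both candidates decode the amax exactly, but each FP4 code c decodes to
+ * c * 0.5 under map-to-6 and to c * 0.75 under map-to-4, so an input of 2.5
+ * lands 0.5 away from the map-to-6 grid and only 0.25 away from the map-to-4
+ * grid; summed over the group, such elements can make map-to-4 the winner.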
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ + +#include +#include +#include +#include +#include + +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +#if FP4_TYPE_SUPPORTED + +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH(ERR_MODE, ERR_MODE_CONST, ...) \ + switch (ERR_MODE) { \ + case kNVTENVFP44Over6ErrMAE: { \ + constexpr NVTENVFP44Over6ErrMode ERR_MODE_CONST = kNVTENVFP44Over6ErrMAE; \ + { __VA_ARGS__ } \ + } break; \ + case kNVTENVFP44Over6ErrMSE: { \ + constexpr NVTENVFP44Over6ErrMode ERR_MODE_CONST = kNVTENVFP44Over6ErrMSE; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported NVFP4 4over6 error mode."); \ + } \ + } + +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH(E4M3_MAX_VALUE, E4M3_MAX_CONST, ...) \ + if ((E4M3_MAX_VALUE) == 256) { \ + constexpr int E4M3_MAX_CONST = 256; \ + { __VA_ARGS__ } \ + } else { \ + NVTE_CHECK((E4M3_MAX_VALUE) == 448, "Unsupported NVFP4 E4M3 max."); \ + constexpr int E4M3_MAX_CONST = 448; \ + { __VA_ARGS__ } \ + } + +namespace quantize_4over6_kernel { + +constexpr int kThreads = 128; +constexpr int kWarpThreads = 32; +constexpr int kGroupSize = 16; +constexpr int kTileRows = 128; +constexpr int kTileCols = 64; +constexpr int kTileColGroups = kTileCols / kGroupSize; +constexpr int kTileRowGroups = kTileRows / kGroupSize; +constexpr int kPipelineStages = 2; +constexpr int kStageRows = kTileRows / kPipelineStages; +constexpr int kStageRowGroups = kStageRows / kGroupSize; +constexpr int kElementsPerHalfGroup = 8; +constexpr int kPackedWordsPerGroup = 2; +static_assert(kTileRows == kPipelineStages * kStageRows); +static_assert(kStageRows % kGroupSize == 0); + +template +struct Config { + static constexpr NVTENVFP44Over6ErrMode err_mode = kErrMode; + static constexpr bool err_use_fast_math = kErrUseFastMath; +}; + +struct Candidate { + uint32_t packed[kPackedWordsPerGroup]; + float err; +}; + +struct CandidatePair { + Candidate map4; + Candidate map6; +}; + +struct ScalePair { + nvfp4_scale_t map4; + nvfp4_scale_t map6; + float inv_map4; + float inv_map6; +}; + +template +__device__ __forceinline__ float compute_error_rn(const float diff) { + if constexpr (kErrMode == kNVTENVFP44Over6ErrMSE) { + return __fmul_rn(diff, diff); + } else if constexpr (kErrMode == kNVTENVFP44Over6ErrMAE) { + return fabsf(diff); + } else { + NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 error mode."); + return fabsf(diff); + } +} + +template +__device__ __forceinline__ float compute_error(const float diff) { + if constexpr (kErrMode == kNVTENVFP44Over6ErrMSE) { + return diff * diff; + } else if constexpr (kErrMode == kNVTENVFP44Over6ErrMAE) { + return fabsf(diff); + } else { + NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 error mode."); + return fabsf(diff); + } +} + +template +__device__ __forceinline__ ScalePair compute_scale_pair(const float block_amax, + const float global_amax) { + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_max = detail::TypeExtrema::max; // 448.0f + constexpr float expand_to_map4 = 1.5f; + const float S_enc = core::compute_global_encode_scaling_factor_FP4(global_amax); + const float base = block_amax / fp4_max * S_enc; + + ScalePair scales; + scales.map4 = static_cast(fminf(base * 
expand_to_map4, fp8_max)); + scales.map6 = static_cast(fminf(base, fp8_max)); + + const float S_dec = 1.0f / S_enc; + scales.inv_map4 = + fminf(1.0f / (static_cast(scales.map4) * S_dec), detail::TypeExtrema::max); + scales.inv_map6 = + fminf(1.0f / (static_cast(scales.map6) * S_dec), detail::TypeExtrema::max); + return scales; +} + +template +__device__ __forceinline__ float load_input(const IType *ptr, const size_t idx) { + return static_cast(ptr[idx]); +} + +template +__device__ __forceinline__ void load_row_group(const IType *tile, const int row, + const int col_start, float (&x0)[8], float (&x1)[8], + float *amax) { + Vec x0_vec; + Vec x1_vec; + x0_vec.load_from(&tile[row * kTileCols + col_start]); + x1_vec.load_from(&tile[row * kTileCols + col_start + kElementsPerHalfGroup]); + + *amax = 0.0f; +#pragma unroll + for (int i = 0; i < kElementsPerHalfGroup; ++i) { + const float v0 = static_cast(x0_vec.data.elt[i]); + const float v1 = static_cast(x1_vec.data.elt[i]); + x0[i] = v0; + x1[i] = v1; + *amax = fmaxf(*amax, fabsf(v0)); + *amax = fmaxf(*amax, fabsf(v1)); + } +} + +template +__device__ __forceinline__ void load_col_group(const IType *tile, const int row_start, + const int col, float (&x0)[8], float (&x1)[8], + float *amax) { + *amax = 0.0f; +#pragma unroll + for (int i = 0; i < kElementsPerHalfGroup; ++i) { + const float v0 = load_input(tile, (row_start + i) * kTileCols + col); + const float v1 = load_input(tile, (row_start + i + kElementsPerHalfGroup) * kTileCols + col); + x0[i] = v0; + x1[i] = v1; + *amax = fmaxf(*amax, fabsf(v0)); + *amax = fmaxf(*amax, fabsf(v1)); + } +} + +template +__device__ __forceinline__ void accumulate_dequant_error(const uint32_t dequant_bits, const float x, + const float sf, const float global_amax, + float *err) { + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_max = static_cast(E4M3_MAX); + constexpr float err_denom = fp4_max * fp8_max; + const uint16_t half_bits = (dequant_bits >> SHIFT) & 0xFFFF; + + if constexpr (Cfg::err_use_fast_math) { + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = dequant * sf * global_amax / err_denom; + const float diff = val - x; + *err += compute_error(diff); + } else { + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = __fdiv_rn(__fmul_rn(__fmul_rn(dequant, sf), global_amax), err_denom); + const float diff = __fsub_rn(val, x); + *err = __fadd_rn(*err, compute_error_rn(diff)); + } +} + +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8], + const float block_scale_inverse, + const nvfp4_scale_t sf, + const float global_amax, + float *err) { + uint32_t out = 0; + uint32_t out_dequant_1 = 0; + uint32_t out_dequant_2 = 0; + uint32_t out_dequant_3 = 0; + uint32_t out_dequant_4 = 0; + + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + asm volatile( + "{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %8, %7;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %10, %9;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %12, %11;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "cvt.rn.f16x2.e2m1x2 %1, byte0;\n" + "cvt.rn.f16x2.e2m1x2 %2, byte1;\n" + "cvt.rn.f16x2.e2m1x2 %3, byte2;\n" + "cvt.rn.f16x2.e2m1x2 %4, byte3;\n" + "}" + : "=r"(out), "=r"(out_dequant_1), "=r"(out_dequant_2), "=r"(out_dequant_3), + "=r"(out_dequant_4) + : "f"(__fmul_rn(x[0], block_scale_inverse)), 
"f"(__fmul_rn(x[1], block_scale_inverse)), + "f"(__fmul_rn(x[2], block_scale_inverse)), "f"(__fmul_rn(x[3], block_scale_inverse)), + "f"(__fmul_rn(x[4], block_scale_inverse)), "f"(__fmul_rn(x[5], block_scale_inverse)), + "f"(__fmul_rn(x[6], block_scale_inverse)), "f"(__fmul_rn(x[7], block_scale_inverse))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + + const float sf_float = static_cast(sf); + accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_1, x[1], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[2], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[3], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[4], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[5], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[6], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[7], sf_float, global_amax, err); + return out; +} + +template +__device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], const float (&x1)[8], + const ScalePair &scales, + const float global_amax) { + CandidatePair candidates; + candidates.map4.err = 0.0f; + candidates.map6.err = 0.0f; + candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error( + x0, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( + x0, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( + x1, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( + x1, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + return candidates; +} + +__device__ __forceinline__ float reduce_group_sum_16(float value) { + const int lane = threadIdx.x & (kWarpThreads - 1); + const int group_base = lane & ~(kGroupSize - 1); + const unsigned mask = 0xffffu << group_base; +#pragma unroll + for (int offset = kGroupSize / 2; offset > 0; offset /= 2) { + value += __shfl_down_sync(mask, value, offset, kGroupSize); + } + return __shfl_sync(mask, value, group_base, kWarpThreads); +} + +__device__ __forceinline__ float reduce_group_max_16(float value) { + const int lane = threadIdx.x & (kWarpThreads - 1); + const int group_base = lane & ~(kGroupSize - 1); + const unsigned mask = 0xffffu << group_base; +#pragma unroll + for (int offset = kGroupSize / 2; offset > 0; offset /= 2) { + value = fmaxf(value, __shfl_down_sync(mask, value, offset, kGroupSize)); + } + return __shfl_sync(mask, value, group_base, kWarpThreads); +} + +__device__ __forceinline__ void store_packed_group(const uint32_t *packed, fp4e2m1x2 *dst) { + const uint64_t packed64 = + static_cast(packed[0]) | (static_cast(packed[1]) << 32); + *reinterpret_cast(dst) = packed64; +} + +__device__ __forceinline__ const uint32_t *select_packed(const CandidatePair &candidates, + const bool pick_map4) { + if (pick_map4) { + return candidates.map4.packed; + } + return candidates.map6.packed; +} + +__device__ __forceinline__ nvfp4_scale_t select_scale(const ScalePair &scales, + const bool pick_map4) { + if (pick_map4) { + return scales.map4; + } + return scales.map6; +} + +__device__ __forceinline__ void cp_async_cg_16(void *dst, const void *src) { +#if (defined __CUDA_ARCH__) && 
(__CUDA_ARCH__ >= 800) + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst); + const uint64_t src_gmem_ptr = reinterpret_cast(src); + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(dst_smem_ptr), + "l"(src_gmem_ptr)); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_commit_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.commit_group;\n" ::); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +template +__device__ __forceinline__ void cp_async_wait_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +template +__device__ void load_stage_to_shared_async(const IType *input, IType *tile, const size_t rows, + const size_t cols, const size_t stage_row, + const size_t tile_col) { + constexpr int vec_elems = 16 / sizeof(IType); + constexpr int vecs_per_row = kTileCols / vec_elems; + constexpr int vecs = kStageRows * vecs_per_row; + using TileVec = Vec; + + for (int idx = threadIdx.x; idx < vecs; idx += blockDim.x) { + const int local_row = idx / vecs_per_row; + const int local_vec_col = idx - local_row * vecs_per_row; + const int local_col = local_vec_col * vec_elems; + const size_t global_row = stage_row + local_row; + const size_t global_col = tile_col + local_col; + IType *stage_ptr = &tile[local_row * kTileCols + local_col]; + + if (global_row < rows && global_col + vec_elems <= cols) { + cp_async_cg_16(stage_ptr, &input[global_row * cols + global_col]); + } else { + TileVec vec; + vec.clear(); +#pragma unroll + for (int i = 0; i < vec_elems; ++i) { + if (global_row < rows && global_col + i < cols) { + vec.data.elt[i] = input[global_row * cols + global_col + i]; + } + } + vec.store_to(stage_ptr); + } + } +} + +template +__device__ void quantize_stage_rowwise(const IType *tile, fp4e2m1x2 *output, nvfp4_scale_t *scales, + const float *amax, const size_t rows, const size_t cols, + const size_t stage_row, const size_t tile_col, + const size_t scale_stride) { + constexpr int groups = kStageRows * kTileColGroups; + for (int group = threadIdx.x; group < groups; group += blockDim.x) { + const int local_row = group % kStageRows; + const int local_col_group = group / kStageRows; + const int local_col = local_col_group * kGroupSize; + const size_t global_row = stage_row + local_row; + const size_t global_col = tile_col + local_col; + if (global_row >= rows || global_col >= cols) { + continue; + } + + float x0[8]; + float x1[8]; + float group_amax = 0.0f; + load_row_group(tile, local_row, local_col, x0, x1, &group_amax); + + float block_amax = group_amax; + if constexpr (USE_2D_QUANTIZATION) { + block_amax = reduce_group_max_16(group_amax); + } + + float global_amax = amax[0]; + if constexpr (ROW_SCALED_NVFP4) { + global_amax = amax[global_row]; + } + + const ScalePair scale_pair = compute_scale_pair(block_amax, global_amax); + CandidatePair candidates = make_candidates(x0, x1, scale_pair, global_amax); + + float err_map4 = candidates.map4.err; + float err_map6 = candidates.map6.err; + if constexpr (USE_2D_QUANTIZATION) { + err_map4 = reduce_group_sum_16(err_map4); + err_map6 = reduce_group_sum_16(err_map6); + } + + const bool pick_map4 = err_map4 < err_map6; + const nvfp4_scale_t selected_scale = select_scale(scale_pair, pick_map4); + const uint32_t *selected = 
select_packed(candidates, pick_map4); + + const size_t global_col_group = global_col / kGroupSize; + scales[global_row * scale_stride + global_col_group] = selected_scale; + store_packed_group(selected, &output[(global_row * cols + global_col) / 2]); + } +} + +template +__device__ void quantize_stage_colwise(const IType *tile, fp4e2m1x2 *output_t, + nvfp4_scale_t *scales_t, const float *amax, + const size_t rows, const size_t cols, const size_t stage_row, + const size_t tile_col, const size_t scale_stride_t) { + constexpr int groups = kStageRowGroups * kTileCols; + for (int group = threadIdx.x; group < groups; group += blockDim.x) { + const int local_row_group = group / kTileCols; + const int local_col = group - local_row_group * kTileCols; + const int local_row = local_row_group * kGroupSize; + const size_t global_row = stage_row + local_row; + const size_t global_col = tile_col + local_col; + if (global_row >= rows || global_col >= cols) { + continue; + } + + float x0[8]; + float x1[8]; + float group_amax = 0.0f; + load_col_group(tile, local_row, local_col, x0, x1, &group_amax); + + float block_amax = group_amax; + if constexpr (USE_2D_QUANTIZATION) { + block_amax = reduce_group_max_16(group_amax); + } + + const float global_amax = amax[0]; + const ScalePair scale_pair = compute_scale_pair(block_amax, global_amax); + CandidatePair candidates = make_candidates(x0, x1, scale_pair, global_amax); + + float err_map4 = candidates.map4.err; + float err_map6 = candidates.map6.err; + if constexpr (USE_2D_QUANTIZATION) { + err_map4 = reduce_group_sum_16(err_map4); + err_map6 = reduce_group_sum_16(err_map6); + } + + const bool pick_map4 = err_map4 < err_map6; + const nvfp4_scale_t selected_scale = select_scale(scale_pair, pick_map4); + const uint32_t *selected = select_packed(candidates, pick_map4); + + const size_t global_row_group = global_row / kGroupSize; + scales_t[global_col * scale_stride_t + global_row_group] = selected_scale; + store_packed_group(selected, &output_t[(global_col * rows + global_row) / 2]); + } +} + +template +__global__ void __launch_bounds__(kThreads) + quantize_4over6_kernel(const IType *input, fp4e2m1x2 *output, fp4e2m1x2 *output_t, + nvfp4_scale_t *scales, nvfp4_scale_t *scales_t, + const float *amax_rowwise, const float *amax_colwise, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const float *noop) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + extern __shared__ char dynamic_shmem[]; + auto *tiles = reinterpret_cast(dynamic_shmem); + const size_t tile_col = blockIdx.x * kTileCols; + const size_t tile_row = blockIdx.y * kTileRows; + + IType *stage_tiles[kPipelineStages]; +#pragma unroll + for (int stage = 0; stage < kPipelineStages; ++stage) { + stage_tiles[stage] = &tiles[stage * kStageRows * kTileCols]; + } + + load_stage_to_shared_async(input, stage_tiles[0], rows, cols, tile_row, tile_col); + cp_async_commit_group(); + cp_async_wait_group<0>(); + __syncthreads(); + + for (int stage = 0; stage < kPipelineStages; ++stage) { + const int next_stage = stage + 1; + if (next_stage < kPipelineStages) { + const size_t next_stage_row = tile_row + next_stage * kStageRows; + load_stage_to_shared_async(input, stage_tiles[next_stage], rows, cols, next_stage_row, + tile_col); + cp_async_commit_group(); + } + + const size_t stage_row = tile_row + stage * kStageRows; + IType *stage_tile = stage_tiles[stage]; + + if constexpr (RETURN_IDENTITY) { + 
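+      // Rowwise pass over the current shared-memory stage: each 1x16 group
+      // along a row (or each 16x16 block in 2D mode, via the group-wide
+      // reductions) picks its map-to-4/map-to-6 winner and writes the packed
+      // FP4 values plus the selected E4M3 block scale.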
quantize_stage_rowwise( + stage_tile, output, scales, amax_rowwise, rows, cols, stage_row, tile_col, scale_stride); + } + + if constexpr (RETURN_TRANSPOSE) { + const float *columnwise_amax = amax_colwise; + if (columnwise_amax == nullptr) { + columnwise_amax = amax_rowwise; + } + quantize_stage_colwise( + stage_tile, output_t, scales_t, columnwise_amax, rows, cols, stage_row, tile_col, + scale_stride_t); + } + + if (next_stage < kPipelineStages) { + cp_async_wait_group<0>(); + __syncthreads(); + } + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif +} + +template +void launch_quantize_4over6(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; + const bool return_identity = output->has_data(); + const bool return_transpose = output->has_columnwise_data(); + + const auto *input_ptr = reinterpret_cast(input.data.dptr); + auto *output_ptr = reinterpret_cast(output->data.dptr); + auto *output_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scales_ptr = reinterpret_cast(output->scale_inv.dptr); + auto *scales_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const auto *amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const auto *amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + + const dim3 grid(DIVUP(cols, static_cast(kTileCols)), + DIVUP(rows, static_cast(kTileRows))); + const dim3 block(kThreads); + const size_t shmem = kPipelineStages * kStageRows * kTileCols * sizeof(IType); + const size_t scale_stride = return_identity ? output->scale_inv.shape[1] : 0; + const size_t scale_stride_t = return_transpose ? 
output->columnwise_scale_inv.shape[1] : 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_identity, RETURN_IDENTITY, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { + auto kernel = quantize_4over6_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem); + kernel<<>>(input_ptr, output_ptr, output_t_ptr, scales_ptr, + scales_t_ptr, amax_rowwise_ptr, amax_colwise_ptr, + rows, cols, scale_stride, scale_stride_t, noop_ptr); + }); + }); + }); +} + +} // namespace quantize_4over6_kernel + +#endif // FP4_TYPE_SUPPORTED + +template +void quantize_4over6(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_4over6_kernel; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(quant_config != nullptr && quant_config->nvfp4_4over6, + "NVFP4 4over6 quantization requires an enabled quantization config."); + NVTE_CHECK(!quant_config->stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + NVTE_CHECK(quant_config->nvfp4_e4m3_max == output->nvfp4_e4m3_max, + "Tensor and quantization config have inconsistent options for NVFP4 4over6 " + "E4M3 scale bound."); + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "NVFP4 4over6 output tensor must have rowwise or columnwise data."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); + NVTE_CHECK(input.flat_last_dim() % kGroupSize == 0, + "NVFP4 4over6 quantization requires columns divisible by ", kGroupSize, "."); + NVTE_CHECK(!(output->has_columnwise_data() || use_2d_quantization) || + input.flat_first_dim() % kGroupSize == 0, + "NVFP4 4over6 columnwise or 2D quantization requires rows divisible by ", kGroupSize, + "."); + NVTE_CHECK(!output->row_scaled_nvfp4 || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!output->row_scaled_nvfp4 || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); + NVTE_CHECK(!use_2d_quantization || output->has_data(), + "NVFP4 4over6 2D quantization requires rowwise output."); + + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise amax tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_amax.dptr != nullptr || output->amax.dptr != nullptr, + "NVFP4 4over6 columnwise quantization requires columnwise amax or rowwise amax."); + } + + TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( + quant_config->nvfp4_e4m3_max, E4M3_MAX, + TRANSFORMER_ENGINE_NVFP4_4OVER6_ERR_MODE_SWITCH( + quant_config->nvfp4_4over6_err_mode, ERR_MODE, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config->nvfp4_4over6_err_use_fast_math, ERR_USE_FAST_MATH, { + using Cfg = quantize_4over6_kernel::Config; + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + 
input.dtype(), IType, + quantize_4over6_kernel::launch_quantize_4over6( + input, noop, output, stream);); + }););); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 28218e2b43..eeb801b758 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -230,6 +230,14 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz chunk.set_row_scaled_nvfp4(source.get_row_scaled_nvfp4()); continue; } + if (param_type == NVTETensorParam::kNVTENVFP44Over6) { + chunk.set_nvfp4_4over6(source.get_nvfp4_4over6()); + continue; + } + if (param_type == NVTETensorParam::kNVTENVFP4E4M3Max) { + chunk.set_nvfp4_e4m3_max(source.get_nvfp4_e4m3_max()); + continue; + } auto param = source.get_parameter(param_type); auto param_dptr = reinterpret_cast(param.data_ptr); auto param_dtype = static_cast(param.dtype); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 12479f2a9c..154dab3143 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -178,6 +178,18 @@ struct Tensor { * Only meaningful for NVFP4 tensors. */ bool row_scaled_nvfp4 = false; + /*! \brief Whether NVFP4 uses 4over6 block scale selection. + * + * Only meaningful for NVFP4 tensors. 4over6 tensors store a selected + * map-to-4/map-to-6 candidate for each 1x16 block. + */ + bool nvfp4_4over6 = false; + /*! \brief Global E4M3 scale bound used by NVFP4. + * + * Standard NVFP4 uses 448. Some 4over6 tensors use 256 to leave room for + * map-to-4 local scale expansion. + */ + int nvfp4_e4m3_max = 448; /*! 
Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -189,7 +201,9 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales - sizeof(uint8_t) // kNVTERowScaledNVFP4 + sizeof(uint8_t), // kNVTERowScaledNVFP4 + sizeof(uint8_t), // kNVTENVFP44Over6 + sizeof(int) // kNVTENVFP4E4M3Max }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -206,6 +220,8 @@ struct Tensor { scaling_mode = NVTE_DELAYED_TENSOR_SCALING; with_gemm_swizzled_scales = false; row_scaled_nvfp4 = false; + nvfp4_4over6 = false; + nvfp4_e4m3_max = 448; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } @@ -477,6 +493,10 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool nvfp4_4over6 = false; + int nvfp4_e4m3_max = 448; + NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; + bool nvfp4_4over6_err_use_fast_math = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -486,7 +506,11 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t), // nvfp4_4over6 + sizeof(int), // nvfp4_e4m3_max + sizeof(uint8_t), // nvfp4_4over6_err_mode + sizeof(uint8_t) // nvfp4_4over6_err_use_fast_math }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 045ae88893..c3b2aada00 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -83,6 +83,19 @@ enum NVTETensorParam { * its values are populated during quantization. */ kNVTERowScaledNVFP4 = 8, + /*! Whether an NVFP4 tensor is encoded with 4over6 semantics. + * + * This records whether block scales were selected by comparing map-to-4 + * and map-to-6 candidates. + */ + kNVTENVFP44Over6 = 9, + /*! Global E4M3 scale bound used by an NVFP4 tensor. + * + * This is part of the tensor data contract. Downstream dequantization and + * GEMM scale consumers must use the same bound used during quantization. + * Standard NVFP4 uses 448; 4over6 may use 256 for map-to-4 headroom. + */ + kNVTENVFP4E4M3Max = 10, kNVTENumTensorParams }; @@ -111,6 +124,14 @@ enum NVTEScalingMode { NVTE_INVALID_SCALING = 100 }; +/*! \enum NVTENVFP44Over6ErrMode + * \brief Candidate-selection error mode for NVFP4 4over6 quantization. + */ +enum NVTENVFP44Over6ErrMode { + kNVTENVFP44Over6ErrMAE = 0, /*!< Select the candidate with lower summed absolute error */ + kNVTENVFP44Over6ErrMSE = 1, /*!< Select the candidate with lower summed squared error */ +}; + /*! \brief TE Tensor type * * NVTETensor is a contiguous tensor type storing a pointer @@ -381,6 +402,34 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! Whether to use NVFP4 4over6 block scale selection. + * + * 4over6 evaluates map-to-4 and map-to-6 candidates for each 1x16 block, + * stores the lower-error candidate according to + * kNVTEQuantizationConfigNVFP44Over6ErrMode. The output tensor's + * kNVTENVFP44Over6 metadata must match this option. 
+ */ + kNVTEQuantizationConfigNVFP44Over6 = 8, + /*! Global E4M3 scale bound to use for NVFP4 quantization. + * + * Standard NVFP4 uses 448. Some 4over6 tensors use 256 to leave room for + * map-to-4 local scale expansion. The output tensor's kNVTENVFP4E4M3Max + * metadata must match this option. + */ + kNVTEQuantizationConfigNVFP4E4M3Max = 9, + /*! Candidate-selection error mode for NVFP4 4over6 quantization. + * + * The value is an NVTENVFP44Over6ErrMode encoded as uint8_t. It is only + * used when kNVTEQuantizationConfigNVFP44Over6 is enabled. + */ + kNVTEQuantizationConfigNVFP44Over6ErrMode = 10, + /*! Whether the NVFP4 4over6 candidate error computation may use fast math. + * + * This is intentionally separate from kNVTEQuantizationConfigUseFastMath so + * callers can keep candidate selection bitwise deterministic independent + * of ordinary NVFP4 fast-math settings. + */ + kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath = 11, kNVTEQuantizationConfigNumAttributes }; @@ -781,6 +830,16 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val)); } + void set_nvfp4_4over6(bool nvfp4_4over6) { + const auto val = static_cast(nvfp4_4over6); + nvte_set_tensor_param_v2(tensor_, kNVTENVFP44Over6, &val, sizeof(val)); + } + + void set_nvfp4_e4m3_max(int nvfp4_e4m3_max) { + const auto val = nvfp4_e4m3_max; + nvte_set_tensor_param_v2(tensor_, kNVTENVFP4E4M3Max, &val, sizeof(val)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -823,6 +882,18 @@ class TensorWrapper { return static_cast(val); } + bool get_nvfp4_4over6() const { + uint8_t val = 0; + nvte_get_tensor_param_v2(tensor_, kNVTENVFP44Over6, &val, sizeof(val), nullptr); + return static_cast(val); + } + + int get_nvfp4_e4m3_max() const { + int val = 448; + nvte_get_tensor_param_v2(tensor_, kNVTENVFP4E4M3Max, &val, sizeof(val), nullptr); + return val; + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -1318,6 +1389,34 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether to use NVFP4 4over6 block scale selection */ + void set_nvfp4_4over6(bool nvfp4_4over6) { + const auto val = static_cast(nvfp4_4over6); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6, &val, + sizeof(val)); + } + + /*! \brief Set the global E4M3 scale bound used by NVFP4 quantization */ + void set_nvfp4_e4m3_max(int nvfp4_e4m3_max) { + const auto val = nvfp4_e4m3_max; + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP4E4M3Max, &val, + sizeof(val)); + } + + /*! \brief Set NVFP4 4over6 candidate-selection error mode */ + void set_nvfp4_4over6_err_mode(NVTENVFP44Over6ErrMode mode) { + const auto val = static_cast(mode); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6ErrMode, &val, + sizeof(val)); + } + + /*! \brief Set whether NVFP4 4over6 candidate error computation uses fast math */ + void set_nvfp4_4over6_err_use_fast_math(bool use_fast_math) { + const auto val = static_cast(use_fast_math); + nvte_set_quantization_config_attribute( + config_, kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath, &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. 
*/ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b773a81d1b..a3f4f2d700 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -13,6 +13,8 @@ _BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") +_NVFP4_4OVER6_SCOPES = (None, "weights", "activations", "all") +_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE") class _FormatHelper(NamedTuple): @@ -522,6 +524,20 @@ class NVFP4BlockScaling(Recipe): If set to `True`, forward activation quantizers emit row-scaled NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored as a vector with one FP32 value per tensor row. + nvfp4_4over6 : {None, 'weights', 'activations', 'all'}, default = None + Select tensors that use NVFP4 4over6. In this mode NVFP4 + quantization evaluates per-block map-to-4 and map-to-6 candidates + and chooses the one with lower configured error. Ties choose map-to-6. The + ``activations`` scope applies to every non-weight tensor role. + Random Hadamard transforms and stochastic rounding are not yet + supported on tensors that use 4over6; activation and backward + scopes therefore require ``disable_rht=True`` and + ``disable_stochastic_rounding=True``. + nvfp4_e4m3_max : {None, 'weights', 'activations', 'all'}, default = None + Select 4over6 tensors that use 256 as the global E4M3 scale + bound. If unset, 4over6 uses the standard NVFP4 448 bound. + nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE' + Error metric used by NVFP4 4over6 candidate selection. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -536,6 +552,9 @@ class NVFP4BlockScaling(Recipe): ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" + nvfp4_4over6: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6", None) + nvfp4_e4m3_max: Optional[str] = os.getenv("NVTE_NVFP4_4OVER6_E4M3_USE_256", None) + nvfp4_4over6_err_mode: str = os.getenv("NVTE_NVFP4_4OVER6_ERR_MODE", "MAE").upper() fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -551,6 +570,21 @@ def __post_init__(self) -> None: assert ( self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." + assert ( + self.nvfp4_4over6 in _NVFP4_4OVER6_SCOPES + ), "NVTE_NVFP4_4OVER6 must be unset or one of: 'weights', 'activations', 'all'." + assert self.nvfp4_e4m3_max in _NVFP4_4OVER6_SCOPES, ( + "NVTE_NVFP4_4OVER6_E4M3_USE_256 must be unset or one of: " + "'weights', 'activations', 'all'." + ) + assert ( + self.nvfp4_4over6_err_mode in _NVFP4_4OVER6_ERR_MODES + ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE'." 
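+        # Illustrative combinations (kwargs as defined on this dataclass):
+        #   NVFP4BlockScaling(nvfp4_4over6="weights")  # weights only; RHT and
+        #                                              # SR remain available
+        #   NVFP4BlockScaling(nvfp4_4over6="all", disable_rht=True,
+        #                     disable_stochastic_rounding=True)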
+ if self.nvfp4_4over6 in ("activations", "all"): + assert self.disable_rht, "NVFP4 4over6 currently requires RHT to be disabled" + assert ( + self.disable_stochastic_rounding + ), "NVFP4 4over6 currently requires stochastic rounding to be disabled" # Quantization params # Note: RHT is currently only applied to column-wise usage so that @@ -580,6 +614,9 @@ def _make_repr(self) -> str: f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " f"row_scaled_activation={self.row_scaled_activation}, " + f"nvfp4_4over6={self.nvfp4_4over6}, " + f"nvfp4_e4m3_max={self.nvfp4_e4m3_max}, " + f"nvfp4_4over6_err_mode={self.nvfp4_4over6_err_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 1c419d4f8c..576e6139c7 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -65,15 +65,15 @@ namespace nvfp4_recipe { * --------------------------------------------------------------------------- */ -// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; -constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); constexpr int kTileDim = 16; constexpr int kThreadsPerBlock = 256; // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, - const float *amax_B, float *alpha_out) { - // factor is defined in the enclosing namespace + const float *amax_B, float fp8_max_A, + float fp8_max_B, float *alpha_out) { + constexpr float fp4_max = 6.0f; + const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max_A * fp8_max_B); *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; } @@ -924,6 +924,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; void *amax_B_ptr = use_rowwise_amax_B ? 
tB->amax.dptr : tB->columnwise_amax.dptr; void *alpha_ptr = tOut->data.dptr; + const float fp8_max_A = static_cast(tA->nvfp4_e4m3_max); + const float fp8_max_B = static_cast(tB->nvfp4_e4m3_max); // check for not null pointers NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); @@ -932,7 +934,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>( alpha_in, reinterpret_cast(amax_A_ptr), - reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); + reinterpret_cast(amax_B_ptr), fp8_max_A, fp8_max_B, + reinterpret_cast(alpha_ptr)); NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a52d76019..2378943526 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -855,6 +855,14 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTERowScaledNVFP4: t.row_scaled_nvfp4 = static_cast(*reinterpret_cast(buf)); break; + case kNVTENVFP44Over6: + t.nvfp4_4over6 = static_cast(*reinterpret_cast(buf)); + break; + case kNVTENVFP4E4M3Max: + std::memcpy(&t.nvfp4_e4m3_max, buf, attr_size); + NVTE_CHECK(t.nvfp4_e4m3_max == 448 || t.nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max (got ", t.nvfp4_e4m3_max, ")"); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -938,6 +946,12 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTERowScaledNVFP4: *reinterpret_cast(buf) = static_cast(t->row_scaled_nvfp4); break; + case kNVTENVFP44Over6: + *reinterpret_cast(buf) = static_cast(t->nvfp4_4over6); + break; + case kNVTENVFP4E4M3Max: + std::memcpy(buf, &t->nvfp4_e4m3_max, attr_size); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -1049,6 +1063,20 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP44Over6: + bool_to_uint8(config_.nvfp4_4over6, buf); + break; + case kNVTEQuantizationConfigNVFP4E4M3Max: + std::memcpy(buf, &config_.nvfp4_e4m3_max, attr_size); + break; + case kNVTEQuantizationConfigNVFP44Over6ErrMode: { + const auto val = static_cast(config_.nvfp4_4over6_err_mode); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath: + bool_to_uint8(config_.nvfp4_4over6_err_use_fast_math, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1104,6 +1132,25 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP44Over6: + uint8_to_bool(buf, config_.nvfp4_4over6); + break; + case kNVTEQuantizationConfigNVFP4E4M3Max: + std::memcpy(&config_.nvfp4_e4m3_max, buf, attr_size); + NVTE_CHECK(config_.nvfp4_e4m3_max == 448 || config_.nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max (got ", config_.nvfp4_e4m3_max, ")"); + break; + case kNVTEQuantizationConfigNVFP44Over6ErrMode: { + const auto val = *reinterpret_cast(buf); + NVTE_CHECK(val == static_cast(kNVTENVFP44Over6ErrMAE) || 
+ val == static_cast(kNVTENVFP44Over6ErrMSE), + "Invalid NVFP4 4over6 error mode (got ", static_cast(val), ")"); + config_.nvfp4_4over6_err_mode = static_cast(val); + break; + } + case kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath: + uint8_to_bool(buf, config_.nvfp4_4over6_err_use_fast_math); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94350da1e6..af8b0835d5 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -327,6 +327,11 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + // Whether emitted NVFP4 tensors use 4over6 candidate selection. + bool nvfp4_use_4over6; + // Global E4M3 scale bound used by emitted NVFP4 tensors. + int nvfp4_e4m3_max; + NVTENVFP44Over6ErrMode nvfp4_4over6_err_mode; // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. bool row_scaled_nvfp4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b38339d67..89f745afd7 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -84,6 +84,8 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization, "2D scaling grouped quant kernel is not ready yet"); + NVTE_CHECK(!nvfp4_quantizer_cpp->nvfp4_use_4over6, + "NVFP4 4over6 quantization is not supported for grouped quantization."); auto quant_config_cpp = QuantizationConfigWrapper(); @@ -722,6 +724,8 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = quantizer_cpp_list[0]->nvfp4_use_4over6; + const int nvfp4_e4m3_max = quantizer_cpp_list[0]->nvfp4_e4m3_max; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); @@ -866,10 +870,12 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, - columnwise_scale, amax_rowwise, amax_columnwise, - fp4_dtype, quantizer_py_list[i], - with_gemm_swizzled_scales, row_scaled_nvfp4)); + tensor_py_list.emplace_back( + NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, + amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, + py::arg("nvfp4_use_4over6") = nvfp4_use_4over6, + py::arg("nvfp4_e4m3_max") = nvfp4_e4m3_max)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -887,6 +893,8 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_usage ? 
columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); + tensor_wrapper.set_nvfp4_4over6(nvfp4_use_4over6); + tensor_wrapper.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -997,6 +1005,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = split_sections.size(); const auto &quantizer = *quantizers.front(); + NVTE_CHECK(!quantizer.nvfp4_use_4over6, + "NVFP4 4over6 quantization is not supported with RHT split quantization."); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1032,6 +1042,17 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, need_separate_rng_states, quant_config_list, quant_config_list_colwise); + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6(quantizer.nvfp4_use_4over6); + config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); + config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); + } + for (auto &config : quant_config_list_colwise) { + config.set_nvfp4_4over6(quantizer.nvfp4_use_4over6); + config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); + config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); + } + // Enable NVFP4 kernels to use math operations that sacrifice // accuracy for performance. These optimizations are experimental // and inconsistently implemented. @@ -1039,8 +1060,10 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { + if (use_fast_math && !quantizer.nvfp4_use_4over6) { for (auto &config : quant_config_list) { config.set_use_fast_math(true); } @@ -1049,6 +1072,17 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_err_use_fast_math(true); + } + for (auto &config : quant_config_list_colwise) { + config.set_nvfp4_4over6_err_use_fast_math(true); + } + } + auto &quant_config_list_colwise_to_use = need_separate_rng_states ? 
quant_config_list_colwise : quant_config_list; @@ -1157,6 +1191,8 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = input_list.size(); const auto &quantizer = *quantizers.front(); + NVTE_CHECK(!quantizer.nvfp4_use_4over6 || !quantizer.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1189,6 +1225,29 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, need_separate_rng_states, quant_config_list, dummy_quant_config_list_colwise); // colwise rng states are not needed in this case + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6(quantizer.nvfp4_use_4over6); + config.set_nvfp4_e4m3_max(quantizer.nvfp4_e4m3_max); + config.set_nvfp4_4over6_err_mode(quantizer.nvfp4_4over6_err_mode); + } + + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math && !quantizer.nvfp4_use_4over6) { + for (auto &config : quant_config_list) { + config.set_use_fast_math(true); + } + } + + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_err_use_fast_math(true); + } + } + // We need: // 1. Rowwise amax = amax for input // 2. Columnwise amax = amax for input too @@ -1259,6 +1318,11 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, "NVFP4 split-quantize does not support 2D quantization"); NVTE_CHECK(!quantizer.with_amax_reduction, "NVFP4 split-quantize does not support amax reduction"); + if (quantizer.nvfp4_use_4over6) { + NVTE_CHECK(!quantizer.with_rht, "NVFP4 4over6 quantization does not support RHT."); + NVTE_CHECK(!quantizer.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + } // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7045995dd7..9dc592fce1 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1729,6 +1729,19 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + this->nvfp4_use_4over6 = quantizer.attr("nvfp4_use_4over6").cast(); + this->nvfp4_e4m3_max = + this->nvfp4_use_4over6 ? 
quantizer.attr("nvfp4_e4m3_max").cast() : 448; + NVTE_CHECK(this->nvfp4_e4m3_max == 448 || this->nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max: ", this->nvfp4_e4m3_max); + const auto nvfp4_4over6_err_mode = quantizer.attr("nvfp4_4over6_err_mode").cast(); + if (nvfp4_4over6_err_mode == "MAE") { + this->nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMAE; + } else if (nvfp4_4over6_err_mode == "MSE") { + this->nvfp4_4over6_err_mode = kNVTENVFP44Over6ErrMSE; + } else { + NVTE_ERROR("Unsupported NVFP4 4over6 error mode: ", nvfp4_4over6_err_mode); + } this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); // Get amax reduction group if needed for NVFP4 AG @@ -1778,6 +1791,8 @@ std::pair NVFP4Quantizer::create_tensor( "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = this->nvfp4_use_4over6; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -1845,6 +1860,8 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1875,6 +1892,8 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["device"] = py::cast(device); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1908,6 +1927,8 @@ std::pair NVFP4Quantizer::create_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); + out_cpp.set_nvfp4_4over6(nvfp4_use_4over6); + out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1936,6 +1957,8 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = this->nvfp4_use_4over6; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -2010,6 +2033,8 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -2085,12 +2110,16 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = this->nvfp4_use_4over6; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, "Row-scaled NVFP4 quantization does not support columnwise usage."); } tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); + tensor.attr("_nvfp4_use_4over6") = py::cast(nvfp4_use_4over6); + tensor.attr("_nvfp4_e4m3_max") = py::cast(nvfp4_e4m3_max); // Coerce row-wise data if (rowwise_usage) { @@ -2195,6 +2224,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); + out_cpp.set_nvfp4_4over6(nvfp4_use_4over6); + out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2285,6 +2316,18 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); + quant_config.set_nvfp4_4over6(this->nvfp4_use_4over6); + quant_config.set_nvfp4_e4m3_max(this->nvfp4_e4m3_max); + quant_config.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); + quant_config_columnwise.set_nvfp4_4over6(this->nvfp4_use_4over6); + quant_config_columnwise.set_nvfp4_e4m3_max(this->nvfp4_e4m3_max); + quant_config_columnwise.set_nvfp4_4over6_err_mode(this->nvfp4_4over6_err_mode); + + if (this->nvfp4_use_4over6) { + NVTE_CHECK(!this->with_rht, "NVFP4 4over6 quantization does not support RHT."); + NVTE_CHECK(!this->stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + } // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input @@ -2425,12 +2468,21 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. 
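+  // Illustrative behavior: with NVTE_USE_FAST_MATH=1 and a 4over6 quantizer,
+  // the ordinary fast-math path below stays off for this tensor, while
+  // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH=1 relaxes only the candidate error
+  // comparison.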
const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { + if (use_fast_math && !this->nvfp4_use_4over6) { quant_config.set_use_fast_math(true); quant_config_columnwise.set_use_fast_math(true); } + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { + quant_config.set_nvfp4_4over6_err_use_fast_math(true); + quant_config_columnwise.set_nvfp4_4over6_err_use_fast_math(true); + } + if (this->with_rht) { if (eligible_for_rht_cast_fusion) { // fusion kernel requires passing in RHT matrix directly for maximum performance diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 37ab0b0535..deb5183916 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -135,6 +135,8 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); + const bool nvfp4_use_4over6 = tensor.attr("_nvfp4_use_4over6").cast(); + const int nvfp4_e4m3_max = tensor.attr("_nvfp4_e4m3_max").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -165,6 +167,8 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); ret.set_row_scaled_nvfp4(row_scaled_nvfp4); + ret.set_nvfp4_4over6(nvfp4_use_4over6); + ret.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index acb7abefd1..42ce32065c 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -221,6 +221,8 @@ class NVFP4TensorRef(QuantizedTensorStorage): scale_t: Optional[torch.Tensor] = None global_amax_row: Optional[torch.Tensor] = None global_amax_col: Optional[torch.Tensor] = None + nvfp4_use_4over6: bool = False + nvfp4_e4m3_max: int = 448 dtype: Optional[torch.dtype] = None device: Optional[torch.device] = None @@ -350,9 +352,15 @@ def __init__( eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", with_rht: bool = False, with_random_sign_mask: bool = True, ): + nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() + if nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -360,6 +368,13 @@ def __init__( raise ValueError( "Row-scaled NVFP4 reference quantization does not support columnwise usage." 
diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py
index acb7abefd1..42ce32065c 100644
--- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py
+++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py
@@ -221,6 +221,8 @@ class NVFP4TensorRef(QuantizedTensorStorage):
     scale_t: Optional[torch.Tensor] = None
     global_amax_row: Optional[torch.Tensor] = None
     global_amax_col: Optional[torch.Tensor] = None
+    nvfp4_use_4over6: bool = False
+    nvfp4_e4m3_max: int = 448
     dtype: Optional[torch.dtype] = None
     device: Optional[torch.device] = None
 
@@ -350,9 +352,15 @@ def __init__(
         eps: float = 0.0,
         quant_tile_shape: Tuple[int, int] = (1, 16),
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
+        nvfp4_4over6_err_mode: str = "MAE",
         with_rht: bool = False,
         with_random_sign_mask: bool = True,
     ):
+        nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper()
+        if nvfp4_4over6_err_mode not in ("MAE", "MSE"):
+            raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.")
         if row_scaled_nvfp4:
             if not rowwise:
                 raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.")
@@ -360,6 +368,13 @@
             raise ValueError(
                 "Row-scaled NVFP4 reference quantization does not support columnwise usage."
             )
+        if nvfp4_use_4over6:
+            if pow_2_scales:
+                raise ValueError("4over6 is only supported for NVFP4 (non-pow2) mode.")
+            if quant_tile_shape not in ((1, 16), (16, 16)):
+                raise ValueError("4over6 reference quantization only supports 1x16 or 16x16 tiles.")
+            if with_rht:
+                raise ValueError("4over6 reference quantization does not support RHT.")
         super().__init__(rowwise=rowwise, columnwise=columnwise)
 
         self.internal = True
@@ -368,6 +383,11 @@
         self.eps = eps
         self.quant_tile_shape = quant_tile_shape
         self.row_scaled_nvfp4 = row_scaled_nvfp4
+        self.nvfp4_use_4over6 = nvfp4_use_4over6
+        self.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448
+        if self.nvfp4_e4m3_max not in (448, 256):
+            raise ValueError("nvfp4_e4m3_max must be 448 or 256.")
+        self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode
         self.with_rht = with_rht
         self.with_random_sign_mask = with_random_sign_mask
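Taken together, the checks above pin down the supported envelope for the reference path: non-pow2 scales only, 1x16 or 16x16 tiles, no RHT, error mode MAE or MSE, and a bound of 448 or 256. A usage sketch under those rules (the constructor is the reference quantizer this hunk modifies; the class name NVFP4QuantizerRef is an assumption for illustration):

    q = NVFP4QuantizerRef(
        rowwise=True,
        columnwise=False,
        quant_tile_shape=(16, 16),
        nvfp4_use_4over6=True,
        nvfp4_e4m3_max=256,           # 448 is the default bound
        nvfp4_4over6_err_mode="mse",  # case-insensitive, normalized to "MSE"
    )

    try:
        NVFP4QuantizerRef(rowwise=True, nvfp4_use_4over6=True, nvfp4_e4m3_max=512)
    except ValueError as err:
        print(err)  # nvfp4_e4m3_max must be 448 or 256.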
@@ -446,6 +466,113 @@ def _recover_swizzled_scales(
         result = torch.reshape(tmp, (rounded_m, rounded_n))
         return result[:m, :scale_n]
 
+    @staticmethod
+    def _quantize_blockwise_4over6_reference(
+        x: torch.Tensor,
+        vec_max: torch.Tensor,
+        global_amax: torch.Tensor,
+        global_encode_scale: torch.Tensor,
+        global_decode_scale: torch.Tensor,
+        row_scaled_nvfp4: bool,
+        tile_len_y: int,
+        nvfp4_4over6_err_mode: str,
+        nvfp4_e4m3_max: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Quantize NVFP4 with 4over6 candidate selection.
+
+        This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale,
+        the configured error is computed in the original input domain with the
+        selected global E4M3 denominator, and ties choose map-to-6.
+        """
+        m, num_blocks, tile_len_x = x.shape
+        n = num_blocks * tile_len_x
+        FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32)
+        FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32)
+        GLOBAL_SCALE_E4M3_MAX = torch.tensor(
+            float(nvfp4_e4m3_max), device=x.device, dtype=torch.float32
+        )
+
+        decode_scale_base = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale
+        decode_scale_map4 = decode_scale_base * 1.5
+        decode_scale_map6 = decode_scale_base
+        decode_scale_map4 = torch.clamp(
+            decode_scale_map4, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX
+        ).to(torch.float8_e4m3fn)
+        decode_scale_map6 = torch.clamp(
+            decode_scale_map6, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX
+        ).to(torch.float8_e4m3fn)
+
+        fp32_max = torch.tensor(
+            torch.finfo(torch.float32).max,
+            device=decode_scale_map4.device,
+            dtype=torch.float32,
+        )
+        encode_scale_map4 = torch.min(
+            torch.div(1.0, decode_scale_map4.to(torch.float32) * global_decode_scale),
+            fp32_max,
+        )
+        encode_scale_map6 = torch.min(
+            torch.div(1.0, decode_scale_map6.to(torch.float32) * global_decode_scale),
+            fp32_max,
+        )
+
+        clipped_x_map4 = torch.clamp(
+            x.to(torch.float32) * encode_scale_map4,
+            -FLOAT4_E2M1_MAX,
+            FLOAT4_E2M1_MAX,
+        ).reshape(m, n)
+        clipped_x_map6 = torch.clamp(
+            x.to(torch.float32) * encode_scale_map6,
+            -FLOAT4_E2M1_MAX,
+            FLOAT4_E2M1_MAX,
+        ).reshape(m, n)
+        qx_map4 = cast_to_fp4x2(clipped_x_map4)
+        qx_map6 = cast_to_fp4x2(clipped_x_map6)
+
+        fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view(m, num_blocks, tile_len_x)
+        fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view(m, num_blocks, tile_len_x)
+        denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX
+        sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1)
+        sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1)
+        if row_scaled_nvfp4:
+            error_global_amax = global_amax.squeeze(-1)
+        else:
+            error_global_amax = global_amax
+        x_float = x.to(torch.float32)
+        err_map4 = torch.zeros_like(vec_max)
+        err_map6 = torch.zeros_like(vec_max)
+        for idx in range(tile_len_x):
+            val_map4 = fp4_map4[:, :, idx] * sf_map4
+            val_map4 = val_map4 * error_global_amax
+            val_map4 = val_map4 / denom
+            diff_map4 = val_map4 - x_float[:, :, idx]
+            if nvfp4_4over6_err_mode == "MSE":
+                err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1)
+            else:
+                err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1)
+
+            val_map6 = fp4_map6[:, :, idx] * sf_map6
+            val_map6 = val_map6 * error_global_amax
+            val_map6 = val_map6 / denom
+            diff_map6 = val_map6 - x_float[:, :, idx]
+            if nvfp4_4over6_err_mode == "MSE":
+                err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1)
+            else:
+                err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1)
+        if tile_len_y == 1:
+            pick_map4 = err_map4 < err_map6
+        else:
+            err_map4_blocks = err_map4.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1)
+            err_map6_blocks = err_map6.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1)
+            pick_map4 = (err_map4_blocks < err_map6_blocks).repeat_interleave(tile_len_y, dim=0)
+        qx = torch.where(
+            pick_map4.expand(-1, -1, tile_len_x // 2),
+            qx_map4.view(m, num_blocks, tile_len_x // 2),
+            qx_map6.view(m, num_blocks, tile_len_x // 2),
+        ).reshape(m, n // 2)
+        decode_scale = torch.where(pick_map4, decode_scale_map4, decode_scale_map6).squeeze(-1)
+        return qx, decode_scale
+
     @classmethod
     def _quantize_blockwise_reference(
         cls,
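The per-column loop above follows the kernel structure one tile column at a time; numerically it is just a sum of per-element errors over the 16-wide tile axis. An equivalent tensorized form, shown only for readability and not part of the patch (dq is a candidate dequantized back to the input domain, exactly as computed inside the loop):

    import torch

    def candidate_error(dq: torch.Tensor, x_float: torch.Tensor, err_mode: str) -> torch.Tensor:
        # dq, x_float: (m, num_blocks, tile_len_x)
        diff = dq - x_float
        per_elem = diff * diff if err_mode == "MSE" else diff.abs()
        # Sum over the tile axis; keepdim matches the (m, num_blocks, 1)
        # shape of the err_map4/err_map6 accumulators above.
        return per_elem.sum(dim=-1, keepdim=True)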
@@ -456,6 +583,9 @@
         *,
         pow_2_scales: bool,
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
+        nvfp4_4over6_err_mode: str = "MAE",
         eps: float,
         # pylint: disable=unused-argument
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -488,6 +618,10 @@
         x = x.view(m, n // tile_len_x, tile_len_x)
         FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32)
         FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32)
+        global_scale_e4m3_max = float(nvfp4_e4m3_max if nvfp4_use_4over6 else 448)
+        GLOBAL_SCALE_E4M3_MAX = torch.tensor(
+            global_scale_e4m3_max, device=x.device, dtype=torch.float32
+        )
 
         decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX)
         if pow_2_scales:
@@ -500,7 +634,7 @@
 
         if row_scaled_nvfp4:
             global_amax = global_amax.to(torch.float32).view(m, 1, 1)
-        global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax)
+        global_encode_scale = torch.div(GLOBAL_SCALE_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax)
         global_encode_scale = torch.min(
             global_encode_scale,
             torch.tensor(
@@ -519,6 +653,22 @@
             global_encode_scale,
         )
         global_decode_scale = torch.div(1.0, global_encode_scale)
+        if nvfp4_use_4over6:
+            # FourOverSix compares map-to-4 and map-to-6 candidates using
+            # the configured original input-domain error, while keeping TE-style FP4
+            # quantization for each candidate.
+            return cls._quantize_blockwise_4over6_reference(
+                x,
+                vec_max,
+                global_amax,
+                global_encode_scale,
+                global_decode_scale,
+                row_scaled_nvfp4,
+                tile_len_y,
+                nvfp4_4over6_err_mode,
+                nvfp4_e4m3_max,
+            )
+
         global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX)
 
         # Match the kernel's default path: fold the FP4 reciprocal into the
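The only change to the shared path is which constant feeds the global encode scale: the selected bound times the FP4 maximum, divided by the global amax. A worked numeric check of that formula (plain arithmetic, independent of the library):

    import torch

    global_amax = torch.tensor(1024.0)
    scale_default = 448.0 * 6.0 / global_amax     # tensor(2.6250), default bound
    scale_4over6_256 = 256.0 * 6.0 / global_amax  # tensor(1.5000), 256 bound
    assert scale_default.item() == 2.625
    assert scale_4over6_256.item() == 1.5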
"weight": + nvfp4_e4m3_max = 256 return NVFP4Quantizer( fp4_dtype=self.dtype, rowwise=True, @@ -1668,6 +1685,9 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and tensor_type != "weight" and self.recipe.row_scaled_activation ), + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.recipe.nvfp4_4over6_err_mode, ) if self.mode not in ("forward", "backward"): diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index f28f972b58..0cc03602a1 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -93,6 +93,8 @@ def __new__( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, ): if ( shapes is not None @@ -166,6 +168,8 @@ def __new__( columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) return instance @@ -198,6 +202,8 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 + dst.nvfp4_use_4over6 = src.nvfp4_use_4over6 + dst.nvfp4_e4m3_max = src.nvfp4_e4m3_max def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index df7a2b4bd3..eb38add62b 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -130,6 +130,12 @@ class NVFP4Quantizer(Quantizer): """Whether emitted NVFP4 tensors store one FP32 amax per row.""" row_scaled_nvfp4: bool + """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection.""" + nvfp4_use_4over6: bool + """Global E4M3 scale bound used by emitted NVFP4 tensors.""" + nvfp4_e4m3_max: int + """NVFP4 4over6 candidate-selection error mode.""" + nvfp4_4over6_err_mode: str """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -147,6 +153,9 @@ def __init__( with_2d_quantization: bool = False, stochastic_rounding: bool = False, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -158,6 +167,13 @@ def __init__( self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 + self.nvfp4_use_4over6 = nvfp4_use_4over6 + self.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 + if self.nvfp4_e4m3_max not in (448, 256): + raise ValueError("nvfp4_e4m3_max must be 448 or 256.") + self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() + if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -204,6 +220,9 @@ def copy(self) -> NVFP4Quantizer: with_2d_quantization=self.with_2d_quantization, 
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index 0c40723517..5872c70963 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -1655,6 +1655,23 @@ def _qparams(tensor_type: str):
 
         def _make(tensor_type: str) -> NVFP4Quantizer:
             qparams = _qparams(tensor_type)
+            nvfp4_use_4over6 = False
+            if self.recipe.nvfp4_4over6 == "all":
+                nvfp4_use_4over6 = True
+            elif self.recipe.nvfp4_4over6 == "weights":
+                nvfp4_use_4over6 = tensor_type == "weight"
+            elif self.recipe.nvfp4_4over6 == "activations":
+                nvfp4_use_4over6 = tensor_type != "weight"
+            nvfp4_e4m3_max = 448
+            if nvfp4_use_4over6:
+                if self.recipe.nvfp4_e4m3_max == "all":
+                    nvfp4_e4m3_max = 256
+                elif self.recipe.nvfp4_e4m3_max == "weights":
+                    if tensor_type == "weight":
+                        nvfp4_e4m3_max = 256
+                elif self.recipe.nvfp4_e4m3_max == "activations":
+                    if tensor_type != "weight":
+                        nvfp4_e4m3_max = 256
             return NVFP4Quantizer(
                 fp4_dtype=self.dtype,
                 rowwise=True,
@@ -1668,6 +1685,9 @@ def _make(tensor_type: str) -> NVFP4Quantizer:
                     and tensor_type != "weight"
                     and self.recipe.row_scaled_activation
                 ),
+                nvfp4_use_4over6=nvfp4_use_4over6,
+                nvfp4_e4m3_max=nvfp4_e4m3_max,
+                nvfp4_4over6_err_mode=self.recipe.nvfp4_4over6_err_mode,
             )
 
         if self.mode not in ("forward", "backward"):
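The branching above reduces to a small role-to-flags mapping. A pure-Python restatement, convenient for unit-testing the recipe plumbing in isolation (the helper name is illustrative):

    from typing import Optional, Tuple

    def resolve_4over6(tensor_type: str,
                       recipe_4over6: Optional[str],
                       recipe_e4m3_max: Optional[str]) -> Tuple[bool, int]:
        is_weight = tensor_type == "weight"
        use_4over6 = (
            recipe_4over6 == "all"
            or (recipe_4over6 == "weights" and is_weight)
            or (recipe_4over6 == "activations" and not is_weight)
        )
        e4m3_max = 448
        if use_4over6 and (
            recipe_e4m3_max == "all"
            or (recipe_e4m3_max == "weights" and is_weight)
            or (recipe_e4m3_max == "activations" and not is_weight)
        ):
            e4m3_max = 256
        return use_4over6, e4m3_max

    assert resolve_4over6("weight", "weights", None) == (True, 448)
    assert resolve_4over6("input", "all", "activations") == (True, 256)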
diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py
index f28f972b58..0cc03602a1 100644
--- a/transformer_engine/pytorch/tensor/grouped_tensor.py
+++ b/transformer_engine/pytorch/tensor/grouped_tensor.py
@@ -93,6 +93,8 @@ def __new__(
         stride: Optional[List[int]] = None,
         with_gemm_swizzled_scales: bool = False,
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
     ):
         if (
             shapes is not None
@@ -166,6 +168,8 @@ def __new__(
             columnwise_scale_inv_offsets=columnwise_scale_inv_offsets,
             with_gemm_swizzled_scales=with_gemm_swizzled_scales,
             row_scaled_nvfp4=row_scaled_nvfp4,
+            nvfp4_use_4over6=nvfp4_use_4over6,
+            nvfp4_e4m3_max=nvfp4_e4m3_max,
         )
         return instance
 
@@ -198,6 +202,8 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non
     dst.quantized_tensors = src.quantized_tensors
     dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales
     dst.row_scaled_nvfp4 = src.row_scaled_nvfp4
+    dst.nvfp4_use_4over6 = src.nvfp4_use_4over6
+    dst.nvfp4_e4m3_max = src.nvfp4_e4m3_max
 
 def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor:
     """Create a wrapper of the same type and tensor metadata as src."""
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index df7a2b4bd3..eb38add62b 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -130,6 +130,12 @@ class NVFP4Quantizer(Quantizer):
     """Whether emitted NVFP4 tensors store one FP32 amax per row."""
     row_scaled_nvfp4: bool
+    """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection."""
+    nvfp4_use_4over6: bool
+    """Global E4M3 scale bound used by emitted NVFP4 tensors."""
+    nvfp4_e4m3_max: int
+    """NVFP4 4over6 candidate-selection error mode."""
+    nvfp4_4over6_err_mode: str
     """RHT matrix random sign mask"""
     rht_matrix_random_sign_mask_t: int
 
@@ -147,6 +153,9 @@ def __init__(
         with_2d_quantization: bool = False,
         stochastic_rounding: bool = False,
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
+        nvfp4_4over6_err_mode: str = "MAE",
         with_random_sign_mask: bool = True,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
@@ -158,6 +167,13 @@
         self.with_2d_quantization = with_2d_quantization
         self.stochastic_rounding = stochastic_rounding
         self.row_scaled_nvfp4 = row_scaled_nvfp4
+        self.nvfp4_use_4over6 = nvfp4_use_4over6
+        self.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448
+        if self.nvfp4_e4m3_max not in (448, 256):
+            raise ValueError("nvfp4_e4m3_max must be 448 or 256.")
+        self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper()
+        if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"):
+            raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.")
         self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
             with_random_sign_mask, torch.cuda.current_device()
         )
@@ -204,6 +220,9 @@ def copy(self) -> NVFP4Quantizer:
             with_2d_quantization=self.with_2d_quantization,
             stochastic_rounding=self.stochastic_rounding,
             row_scaled_nvfp4=self.row_scaled_nvfp4,
+            nvfp4_use_4over6=self.nvfp4_use_4over6,
+            nvfp4_e4m3_max=self.nvfp4_e4m3_max,
+            nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode,
         )
         quantizer.internal = self.internal
         quantizer.optimize_for_gemm = self.optimize_for_gemm
@@ -356,6 +375,8 @@ def __new__(
         quantizer: Quantizer,
         with_gemm_swizzled_scales: bool,
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
         **kwargs,
     ):
         instance = super().__new__(
@@ -371,6 +392,8 @@
             with_gemm_swizzled_scales,
             *args,
             row_scaled_nvfp4=row_scaled_nvfp4,
+            nvfp4_use_4over6=nvfp4_use_4over6,
+            nvfp4_e4m3_max=nvfp4_e4m3_max,
             **kwargs,
         )
         return instance
@@ -528,6 +551,9 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
             columnwise_usage,
             self._amax_rowwise,
             self._amax_columnwise,
+            self._row_scaled_nvfp4,
+            self._nvfp4_use_4over6,
+            self._nvfp4_e4m3_max,
             self.shape[-1],
         )
         return sharded_tensors, metadata
@@ -546,7 +572,16 @@ def fsdp_post_all_gather(
         all-gathered rowwise data. Columnwise data is derived locally via
         _create_columnwise() instead of being all-gathered.
         """
-        fp4_dtype, columnwise_usage, amax_rowwise, amax_columnwise, K = metadata
+        (
+            fp4_dtype,
+            columnwise_usage,
+            amax_rowwise,
+            amax_columnwise,
+            row_scaled_nvfp4,
+            nvfp4_use_4over6,
+            nvfp4_e4m3_max,
+            K,
+        ) = metadata
 
         # Only rowwise data+scales were all-gathered
         rowwise_data, rowwise_scale_inv = all_gather_outputs[:2]
@@ -569,6 +604,9 @@
             out._rowwise_scale_inv = rowwise_scale_inv
             out._amax_rowwise = amax_rowwise
             out._amax_columnwise = amax_columnwise
+            out._row_scaled_nvfp4 = row_scaled_nvfp4
+            out._nvfp4_use_4over6 = nvfp4_use_4over6
+            out._nvfp4_e4m3_max = nvfp4_e4m3_max
         else:
             # Construct new tensor (first iteration)
             out = NVFP4Tensor(
@@ -584,6 +622,9 @@
                 quantizer=self._quantizer,
                 requires_grad=False,
                 with_gemm_swizzled_scales=False,
+                row_scaled_nvfp4=row_scaled_nvfp4,
+                nvfp4_use_4over6=nvfp4_use_4over6,
+                nvfp4_e4m3_max=nvfp4_e4m3_max,
             )
 
         # Derive columnwise data locally via transpose instead of all-gathering it
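The all-gather metadata tuple is positional, so the three new fields must occupy the same slots on both sides. A sketch restating that contract (the function name is illustrative; field names are those used in the hunks above):

    def pack_fsdp_metadata(fp4_dtype, columnwise_usage, amax_rowwise, amax_columnwise,
                           row_scaled_nvfp4, nvfp4_use_4over6, nvfp4_e4m3_max, K):
        # Order must match fsdp_pre_all_gather; fsdp_post_all_gather unpacks
        # positionally, so reordering any field silently corrupts the rest.
        return (fp4_dtype, columnwise_usage, amax_rowwise, amax_columnwise,
                row_scaled_nvfp4, nvfp4_use_4over6, nvfp4_e4m3_max, K)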
@@ -722,6 +763,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 quantizer=tensor._quantizer,
                 requires_grad=tensor.requires_grad,
                 with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
+                row_scaled_nvfp4=tensor._row_scaled_nvfp4,
+                nvfp4_use_4over6=tensor._nvfp4_use_4over6,
+                nvfp4_e4m3_max=tensor._nvfp4_e4m3_max,
             )
 
         # Default case
@@ -741,6 +785,9 @@ def _make_in_reduce_ex(
     dtype: torch.dtype,
     quantizer: Quantizer,
     with_gemm_swizzled_scales: bool = False,
+    row_scaled_nvfp4: bool = False,
+    nvfp4_use_4over6: bool = False,
+    nvfp4_e4m3_max: int = 448,
 ) -> NVFP4Tensor:
     """Build NVFP4Tensor, for use in __reduce__
 
@@ -761,6 +808,9 @@
         quantizer=quantizer,
         requires_grad=False,
         with_gemm_swizzled_scales=with_gemm_swizzled_scales,
+        row_scaled_nvfp4=row_scaled_nvfp4,
+        nvfp4_use_4over6=nvfp4_use_4over6,
+        nvfp4_e4m3_max=nvfp4_e4m3_max,
     )
 
 def __reduce_ex__(self, protocol: int) -> tuple:
@@ -779,6 +829,9 @@
             self.dtype,
             self._quantizer,
             self._with_gemm_swizzled_scales,
+            self._row_scaled_nvfp4,
+            self._nvfp4_use_4over6,
+            self._nvfp4_e4m3_max,
         ),
     )
 
@@ -831,6 +884,9 @@ def _set_data(self, tensor: torch.Tensor) -> None:
             self._amax_rowwise = tensor._amax_rowwise
             self._amax_columnwise = tensor._amax_columnwise
             self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
+            self._row_scaled_nvfp4 = tensor._row_scaled_nvfp4
+            self._nvfp4_use_4over6 = tensor._nvfp4_use_4over6
+            self._nvfp4_e4m3_max = tensor._nvfp4_e4m3_max
             return
 
         # Quantize to FP8
@@ -951,6 +1007,9 @@ def forward(
             fp4_dtype=tensor._fp4_dtype,
             requires_grad=tensor.requires_grad,
             with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
+            row_scaled_nvfp4=tensor._row_scaled_nvfp4,
+            nvfp4_use_4over6=tensor._nvfp4_use_4over6,
+            nvfp4_e4m3_max=tensor._nvfp4_e4m3_max,
         )
 
     @staticmethod
@@ -993,6 +1052,9 @@ def backward(
                 fp4_dtype=grad._fp4_dtype,
                 requires_grad=grad.requires_grad,
                 with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
+                row_scaled_nvfp4=grad._row_scaled_nvfp4,
+                nvfp4_use_4over6=grad._nvfp4_use_4over6,
+                nvfp4_e4m3_max=grad._nvfp4_e4m3_max,
             )
             return dgrad, None
         return grad.view(ctx.shape), None
@@ -1077,6 +1139,9 @@ def forward(
             fp4_dtype=tensor._fp4_dtype,
             requires_grad=tensor.requires_grad,
             with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
+            row_scaled_nvfp4=tensor._row_scaled_nvfp4,
+            nvfp4_use_4over6=tensor._nvfp4_use_4over6,
+            nvfp4_e4m3_max=tensor._nvfp4_e4m3_max,
        )
 
     @staticmethod
@@ -1119,6 +1184,9 @@ def backward(
                 fp4_dtype=grad._fp4_dtype,
                 requires_grad=grad.requires_grad,
                 with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
+                row_scaled_nvfp4=grad._row_scaled_nvfp4,
+                nvfp4_use_4over6=grad._nvfp4_use_4over6,
+                nvfp4_e4m3_max=grad._nvfp4_e4m3_max,
             )
             return dgrad, None
         return grad.view(ctx.shape), None
diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
index ac56d334bc..438e124021 100644
--- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
@@ -73,6 +73,8 @@ def _initialize_storage_fields(
         stride: Optional[List[int]] = None,
         with_gemm_swizzled_scales: bool = False,
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
     ) -> None:
         """
         Initialize a GroupedTensor.
@@ -149,6 +151,8 @@ def _initialize_storage_fields(
         instance.quantized_tensors = None
         instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
         instance.row_scaled_nvfp4 = row_scaled_nvfp4
+        instance.nvfp4_use_4over6 = nvfp4_use_4over6
+        instance.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448
 
     def __new__(
         cls,
@@ -175,6 +179,8 @@ def __new__(
         stride: Optional[List[int]] = None,
         with_gemm_swizzled_scales: bool = False,
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
     ):
         instance = object.__new__(cls)
         cls._initialize_storage_fields(
@@ -201,6 +207,8 @@ def __new__(
             stride=stride,
             with_gemm_swizzled_scales=with_gemm_swizzled_scales,
             row_scaled_nvfp4=row_scaled_nvfp4,
+            nvfp4_use_4over6=nvfp4_use_4over6,
+            nvfp4_e4m3_max=nvfp4_e4m3_max,
         )
         return instance
 
@@ -307,6 +315,33 @@ def get_dtype(self) -> torch.dtype:
 
         return self.fake_dtype
 
+    @property
+    def row_scaled_nvfp4(self) -> bool:
+        """Whether grouped NVFP4 tensors use row-scaled amax metadata."""
+        return self._row_scaled_nvfp4
+
+    @row_scaled_nvfp4.setter
+    def row_scaled_nvfp4(self, row_scaled_nvfp4: bool) -> None:
+        self._row_scaled_nvfp4 = row_scaled_nvfp4
+
+    @property
+    def nvfp4_use_4over6(self) -> bool:
+        """Whether grouped NVFP4 tensors carry 4over6 metadata."""
+        return self._nvfp4_use_4over6
+
+    @nvfp4_use_4over6.setter
+    def nvfp4_use_4over6(self, nvfp4_use_4over6: bool) -> None:
+        self._nvfp4_use_4over6 = nvfp4_use_4over6
+
+    @property
+    def nvfp4_e4m3_max(self) -> int:
+        """Global E4M3 scale bound used by grouped NVFP4 tensors."""
+        return self._nvfp4_e4m3_max
+
+    @nvfp4_e4m3_max.setter
+    def nvfp4_e4m3_max(self, nvfp4_e4m3_max: int) -> None:
+        self._nvfp4_e4m3_max = nvfp4_e4m3_max
+
     def prepare_for_saving(
         self,
     ) -> Tuple[list[Optional[torch.Tensor]], "GroupedTensorStorage"]:
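As in the quantizer and the per-tensor storage, the grouped storage normalizes the bound at construction: a bound of 256 only sticks while 4over6 is on. A behavior sketch of that rule (helper name illustrative):

    def normalize_e4m3_max(nvfp4_use_4over6: bool, nvfp4_e4m3_max: int) -> int:
        # Mirrors `nvfp4_e4m3_max if nvfp4_use_4over6 else 448` above.
        return nvfp4_e4m3_max if nvfp4_use_4over6 else 448

    assert normalize_e4m3_max(True, 256) == 256
    assert normalize_e4m3_max(False, 256) == 448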
@@ -376,6 +411,8 @@ def clear(self) -> None:
         self.tensor_shapes = []
         self.fake_dtype = torch.float32
         self.row_scaled_nvfp4 = False
+        self.nvfp4_use_4over6 = False
+        self.nvfp4_e4m3_max = 448
 
     def __repr__(self) -> str:
         """String representation of the GroupedTensorStorage."""
@@ -545,6 +582,8 @@ def copy(self) -> "GroupedTensorStorage":
             columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
             row_scaled_nvfp4=self.row_scaled_nvfp4,
+            nvfp4_use_4over6=self.nvfp4_use_4over6,
+            nvfp4_e4m3_max=self.nvfp4_e4m3_max,
         )
 
     @staticmethod
@@ -656,6 +695,8 @@ def make_grouped_tensor(
         scale_inv_offsets = None
         columnwise_scale_inv_offsets = None
         row_scaled_nvfp4 = False
+        nvfp4_use_4over6 = False
+        nvfp4_e4m3_max = 448
         if no_quantization:
             assert dtype is not None, "dtype must be provided for unquantized GroupedTensor"
             if rowwise_usage:
@@ -715,6 +756,8 @@ def make_grouped_tensor(
                 amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
             elif quantizer._get_compatible_recipe().nvfp4():
                 row_scaled_nvfp4 = quantizer.row_scaled_nvfp4
+                nvfp4_use_4over6 = quantizer.nvfp4_use_4over6
+                nvfp4_e4m3_max = quantizer.nvfp4_e4m3_max
                 if row_scaled_nvfp4:
                     if not rowwise_usage:
                         raise ValueError(
@@ -843,6 +886,8 @@ def make_grouped_tensor(
             with_gemm_swizzled_scales=(
                 quantizer.optimize_for_gemm if quantizer is not None else False
             ),
             row_scaled_nvfp4=row_scaled_nvfp4,
+            nvfp4_use_4over6=nvfp4_use_4over6,
+            nvfp4_e4m3_max=nvfp4_e4m3_max,
         )
         grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
         return grouped_tensor
@@ -957,6 +1002,8 @@ def split_into_quantized_tensors(
         self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets
         nvfp4_rowwise_amax_offsets = None
         row_scaled_nvfp4 = self.row_scaled_nvfp4
+        nvfp4_use_4over6 = self.nvfp4_use_4over6
+        nvfp4_e4m3_max = self.nvfp4_e4m3_max
         if recipe.nvfp4() and row_scaled_nvfp4:
             cum = 0
             nvfp4_rowwise_amax_offsets = [0]
@@ -1184,6 +1231,8 @@ def split_into_quantized_tensors(
                 quantizer=quantizer,
                 with_gemm_swizzled_scales=quantizer.optimize_for_gemm,
                 row_scaled_nvfp4=row_scaled_nvfp4,
+                nvfp4_use_4over6=nvfp4_use_4over6,
+                nvfp4_e4m3_max=nvfp4_e4m3_max,
             )
             result.append(tensor)
 
diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
index e51acb71e5..7a7f6012dd 100644
--- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -99,6 +99,10 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
     _with_gemm_swizzled_scales: bool
     # Whether this NVFP4 tensor uses row-scaled amax metadata
     _row_scaled_nvfp4: bool
+    # Whether this NVFP4 tensor uses 4over6 map-to-4/map-to-6 block selection
+    _nvfp4_use_4over6: bool
+    # Global E4M3 scale bound used by this NVFP4 tensor
+    _nvfp4_e4m3_max: int
 
     def __new__(
         cls,
@@ -114,6 +118,8 @@ def __new__(
         *args,
         fake_dtype: Optional[torch.dtype] = None,
         row_scaled_nvfp4: bool = False,
+        nvfp4_use_4over6: bool = False,
+        nvfp4_e4m3_max: int = 448,
         **kwargs,
     ):
         if cls is NVFP4TensorStorage:
@@ -132,6 +138,8 @@ def __new__(
         instance._amax_columnwise = amax_columnwise
         instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
         instance._row_scaled_nvfp4 = row_scaled_nvfp4
+        instance._nvfp4_use_4over6 = nvfp4_use_4over6
+        instance._nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448
 
         return instance
 
@@ -158,6 +166,10 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
             raise RuntimeError("Scale layout mismatch in copy_from_storage")
         if self._row_scaled_nvfp4 != src._row_scaled_nvfp4:
             raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage")
+        if self._nvfp4_use_4over6 != src._nvfp4_use_4over6:
+            raise RuntimeError("NVFP4 4over6 mode mismatch in copy_from_storage")
+        if self._nvfp4_e4m3_max != src._nvfp4_e4m3_max:
+            raise RuntimeError("NVFP4 4over6 E4M3 scale bound mismatch in copy_from_storage")
 
         def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
             if dst is not None and src_tensor is not None:
@@ -183,6 +195,8 @@ def get_metadata(self) -> Dict[str, Any]:
             "quantizer": self._quantizer,
             "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
             "row_scaled_nvfp4": self._row_scaled_nvfp4,
+            "nvfp4_use_4over6": self._nvfp4_use_4over6,
+            "nvfp4_e4m3_max": self._nvfp4_e4m3_max,
             "fake_dtype": self._dtype,
         }
 
@@ -316,6 +330,8 @@ def view(self, shape: torch.Size):
             fp4_dtype=self._fp4_dtype,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
             row_scaled_nvfp4=self._row_scaled_nvfp4,
+            nvfp4_use_4over6=self._nvfp4_use_4over6,
+            nvfp4_e4m3_max=self._nvfp4_e4m3_max,
             fake_dtype=self._dtype,
         )
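The copy_from_storage guards make mixed-metadata copies fail loudly instead of producing tensors whose scales would decode under the wrong bound. A minimal reproduction of the failure mode (the stand-in class is hypothetical; the real storage also carries data tensors):

    class _Meta:  # hypothetical stand-in for NVFP4TensorStorage metadata
        def __init__(self, use_4over6: bool, e4m3_max: int):
            self._nvfp4_use_4over6 = use_4over6
            self._nvfp4_e4m3_max = e4m3_max

    def check_copy_compatible(dst: "_Meta", src: "_Meta") -> None:
        if dst._nvfp4_use_4over6 != src._nvfp4_use_4over6:
            raise RuntimeError("NVFP4 4over6 mode mismatch in copy_from_storage")
        if dst._nvfp4_e4m3_max != src._nvfp4_e4m3_max:
            raise RuntimeError("NVFP4 4over6 E4M3 scale bound mismatch in copy_from_storage")

    check_copy_compatible(_Meta(True, 256), _Meta(True, 256))    # ok
    # check_copy_compatible(_Meta(True, 256), _Meta(True, 448))  # raises RuntimeError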