Implement 4over6 NVFP4 recipe (#2972)
Conversation
Greptile Summary

This PR implements the NVFP4 4over6 quantization algorithm from the FourOverSix paper, enabling per-block map-to-4 vs. map-to-6 candidate selection for 1D and 2D NVFP4 quantization. The feature is gated behind the new `NVTE_NVFP4_4OVER6` environment variable.
Confidence Score: 4/5

Safe to merge with awareness of the pre-existing fast-math gap on the single-tensor path; all new 4over6 code is well-guarded and the core quantization logic aligns with the reference. The 4over6 feature is implemented thoroughly across all 41 changed files. The flag threads correctly through the Python recipe, tensor metadata, C++ quantizer, and all CUDA dispatch paths. Validation checks reject incompatible combinations at multiple layers. The Python reference mirrors the CUDA MSE logic with the correct 256 denominator and 1.5x scale expansion. The only new findings are minor style-level concerns.

`transformer_engine/pytorch/csrc/extensions/cast.cpp`: the `use_fast_math` env-var read is only wired into the split-quantize helper, not into the single-tensor `quantize_impl`, so the fast-math variant of the 4over6 MSE kernel is unreachable for ordinary single-tensor calls.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["NVFP4BlockScaling recipe\nnvfp4_4over6: weights/activations/all/None"] --> B["RecipeState._make\nresolves use_4over6 per tensor_type"]
    B --> C["NVFP4Quantizer with use_4over6"]
    C --> D["create_tensor / quantize_impl"]
    D --> E{use_4over6?}
    E -->|No| G["Standard NVFP4 path\n448-based global scale"]
    E -->|Yes| F{valid combo?}
    F -->|RHT or SR or grouped| ERR["NVTE_CHECK error"]
    F -->|OK| H["compute_global_encode_scaling_factor 256-based"]
    H --> I["compute_4over6_decoding_scaling_factors\nmap6 and map4 candidates"]
    I --> J["cvt_fp32_to_fp4_8x_with_mse_rn x2\naccumulate MSE for both candidates"]
    J --> K{err_map4 < err_map6?}
    K -->|Yes| L["rOut_map4 selected"]
    K -->|No - ties go to map6| M["rOut_map6 selected"]
    L --> N["NVFP4Tensor _use_4over6=True\nglobal E4M3 bound = 256"]
    M --> N
    N --> O["Dequant and GEMM scale use 256 denominator"]
```
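For readers who want the flowchart's selection step in concrete terms, here is a minimal NumPy sketch of per-block map-to-4 vs. map-to-6 candidate selection. It is a simplified reference, not the kernel: the helper names, the nearest-value rounding, and the omission of E4M3 quantization of the block decode scales are all assumptions for illustration.

```python
import numpy as np

FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # positive E2M1 magnitudes

def quantize_block(block, dec_scale):
    """Round each element to the nearest FP4 magnitude under the given decode scale."""
    mags = np.abs(block) / dec_scale
    q = FP4_GRID[np.argmin(np.abs(mags[:, None] - FP4_GRID[None, :]), axis=1)]
    return np.sign(block) * q * dec_scale

def choose_4over6(block):
    """Pick the lower-MSE candidate between map-to-6 and map-to-4 block scales."""
    amax = np.abs(block).max()
    if amax == 0.0:
        return "map6", 0.0                 # all-zero block: nothing to choose
    s6 = amax / 6.0                        # map-to-6: block amax lands on FP4 value 6
    s4 = amax / 4.0                        # map-to-4: decode scale is 1.5x larger
    err6 = np.mean((quantize_block(block, s6) - block) ** 2)
    err4 = np.mean((quantize_block(block, s4) - block) ** 2)
    return ("map4", s4) if err4 < err6 else ("map6", s6)  # ties go to map-to-6

rng = np.random.default_rng(0)
print(choose_4over6(rng.standard_normal(16)))  # one 16-element NVFP4 block
```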
Reviews (7): Last reviewed commit: "Drop write back lifting"

Functionality has been verified by internal RL experiments.

Need to rebase.
```cpp
   * its values are populated during quantization.
   */
  kNVTERowScaledNVFP4 = 8,
  kNVTENVFP44Over6 = 9, /*!< Whether an NVFP4 tensor uses 4over6 scaling */
```
We are specifying this redundantly in NVTETensor and NVTEQuantizationConfig. If this option can be isolated to quantization, then we should not add clutter to the tensor. If the option is needed for downstream consumers (dequantization, GEMM), then it should be treated as part of the tensor data. I'm not especially familiar, but 4over6 seems like it should be specific to quantization.
4over6 changes the decode convention from 1 / (6 * 448) to 1 / (6 * 256). Therefore, for our current representation 4over6 is part of the tensor data contract, not just a quantization option.
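To make that contract concrete, here is a one-function Python sketch (the helper name and signature are hypothetical, not the actual API) of how the decode denominator changes with the flag:

```python
def global_encode_scale(global_amax: float, use_4over6: bool) -> float:
    # Standard NVFP4 maps global_amax onto fp4_max * fp8_max = 6 * 448;
    # 4over6 replaces 448 with 256, and dequantization must divide by the
    # same 6 * 256 -- which is why the flag has to travel with the tensor.
    fp8_max = 256.0 if use_4over6 else 448.0
    return (6.0 * fp8_max) / global_amax
```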
```diff
   using namespace detail;
-  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;  // 448.0f
   constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f
+  constexpr float fp8_max = USE_4OVER6 ? 256.0f : TypeExtrema<fp8e4m3>::max;  // 448.0f
```
How much benefit does changing the FP8 scale have on convergence? If we don't see a clear benefit, then it would be nicer to use the same scale for 4over6 and non-4over6. That way we can keep this logic confined to quantization, and downstream consumers are completely unaffected.

If there is an impact on training quality, we should still consider disentangling the FP8 scaling from 4over6. I don't see why other NVFP4 recipes might not benefit from tweaking the scaling.
From the original paper:

> Finally, we make one modification to the computation of the tensor scale α (Equation 1) when quantizing to NVFP4 with 4/6. When M_FP4 × M_FP8 is used to compute the tensor scale, it ensures that all quantized values will be less than 6 × 448. However, this makes it impossible to select a scale of 4 for the blocks that contain a tensor's largest values, because the block's scale would need to be 448 × 6/4 = 672, which would overflow since 448 is the maximum value that can be represented by E4M3. As a result, when computing the tensor scale, we replace M_FP8 with 256 in Equation 1, since 256 is the largest E4M3 value that can be multiplied by 6/4 and represented without error in E4M3, as 384.
Also:

> In Section 3.1, we propose calculating the FP32 global tensor scale using 256 as the maximum FP8 E4M3 value rather than the default of 448, as this allows blocks with a tensor's largest value to have the option to have a largest FP4 value of 4. In Figure 6, we find that this provides a marginal benefit over using the standard tensor scale calculation. Even though this adjustment only affects a small number of large values, this performance gain may come from the fact that larger activation values can have an outsize impact on model performance. This adjustment is incorporated into the remaining experiments in this section.
Not sure if there are internal or external studies about the convergence, but this is required to make it work. We need the largest value that is smaller than 448/1.5, is itself exactly representable in E4M3, and whose product with 1.5 is also exactly representable in E4M3. This helps avoid quantization noise on both the map-to-4 and map-to-6 paths.
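This claim is easy to verify mechanically. A small Python sketch that enumerates the finite non-negative E4M3 (FP8 E4M3FN) values and searches for the largest one below 448/1.5 whose 1.5x product is itself an exact E4M3 value:

```python
def e4m3_values():
    """All finite non-negative OCP FP8 E4M3FN values (4 exponent bits, bias 7, 3 mantissa bits)."""
    vals = set()
    for e in range(16):
        for m in range(8):
            if e == 0:
                vals.add(m / 8 * 2.0 ** -6)       # subnormals
            elif e == 15 and m == 7:
                continue                          # this encoding is NaN, not 480
            else:
                vals.add((1 + m / 8) * 2.0 ** (e - 7))
    return sorted(vals)

vals = e4m3_values()
exact = set(vals)
assert max(vals) == 448.0
candidates = [v for v in vals if v < 448.0 / 1.5 and 1.5 * v in exact]
print(max(candidates))  # 256.0 -- and 256 * 1.5 = 384 is itself exact E4M3
```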
We did find that using 256 to calculate the second-level scaling factor helped convergence vs. 448, but only slightly.

It's possible that the premise of the paper's argument (prevent saturation when map-to-4 scaling effectively multiplies the block decode scale by 1.5x) is sound, but a value larger than 256 can achieve this, and perfectly representing the block containing the global amax under both scalings may not be worth the extra range loss.
Let me make the 256 scaling a separate env var, disabled by default.
448, 320, 288, 256 are all potential candidates for map-to-6:

- 448: effectively disables the map-to-4 option above 256, preserves range
- 320, 288: map-to-4 uses 448, no precise 1.5x
- 256: map-to-4 uses 384, precise 1.5x

For now let me refactor the interface to `NVTE_NVFP4_4OVER6_E4M3="448"|"256"`, defaulting to "448", and dispatch to a numeric template parameter in the C++ code instead of a boolean toggle. People can add support for other values or make it more generic (like directly parsing the env var digits) in the future.
NVTE_NVFP4_4OVER6_E4M3_USE_256=weights|activations|all is a cleaner pattern and allows separate configuration.
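A sketch of the scoped env-var pattern proposed here; the function name and error handling are illustrative, not the actual implementation:

```python
import os

def resolve_scope(var_name: str, tensor_type: str) -> bool:
    """Resolve a weights|activations|all scoped env var to a per-tensor boolean.

    tensor_type is "weights" or "activations"; an unset variable means disabled.
    """
    scope = os.environ.get(var_name)
    if scope is None:
        return False
    if scope not in ("weights", "activations", "all"):
        raise ValueError(f"{var_name} must be weights|activations|all, got {scope!r}")
    return scope == "all" or scope == tensor_type

# e.g. resolve_scope("NVTE_NVFP4_4OVER6", "weights")
```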
This test is okay, but it would provide much more confidence if the NVFP4 quantization tests compared against a CPU reference impl.
Extended tests/cpp/operator/test_cast_nvfp4_transpose.cu coverage in 3bb42b1.
```text
nvfp4_4over6 : {None, 'weights', 'activations', 'all'}, default = None
    Select tensors that use NVFP4 4over6. In this mode NVFP4
    quantization evaluates per-block map-to-4 and map-to-6 candidates
    and chooses the one with lower MSE. Ties choose map-to-6. The
```
We need both MSE (better for post-training?) and MAE (better for pre-training as per our internal studies) to be supported, with MAE as the default.
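Supporting both metrics would be a small change to the candidate scoring; a standalone Python sketch of the metric swap (names are illustrative, not the actual implementation):

```python
import numpy as np

def block_error(block, dequant, metric="mae"):
    """Per-block reconstruction error between the original and a dequantized candidate."""
    diff = np.asarray(dequant) - np.asarray(block)
    if metric == "mse":
        return np.mean(diff ** 2)
    if metric == "mae":
        return np.mean(np.abs(diff))
    raise ValueError(f"unknown metric: {metric!r}")

# Candidate selection itself is unchanged: score the map-to-6 and map-to-4
# dequantized candidates with the chosen metric and keep the smaller,
# with ties still going to map-to-6.
```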
Description
Implement 4over6 NVFP4 from:
FlashInfer PR:
Enable per-block map-to-4 versus map-to-6 candidate selection for 1D/2D NVFP4 quantization in the `NVFP4BlockScaling` recipe. This mode currently requires RHT and stochastic rounding to be disabled. Both the original per-tensor scaling and the row-scaling NVFP4 introduced by #2931 are supported.

This PR also fixes a few minor bugs for row-scaled NVFP4 from #2931.
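A minimal usage sketch based on the recipe docstring in this PR; the module paths follow Transformer Engine's usual layout but may differ, and any required sibling recipe arguments are omitted:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

# Enable 4over6 candidate selection for weight tensors only; per this PR,
# RHT and stochastic rounding must be disabled in this mode.
recipe = NVFP4BlockScaling(nvfp4_4over6="weights")

linear = te.Linear(1024, 1024).cuda()
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = linear(torch.randn(16, 1024, device="cuda"))
```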
Type of change
Changes
Please list the changes introduced in this PR:
- Adds `NVTE_NVFP4_4OVER6=weights|activations|all`, with unset preserving existing behavior, and threads the selected scope through recipes, quantizers, tensor metadata, split quantization, single-tensor quantization, and C++ tensor/config APIs.
- Adds the 4over6 MSE candidate-selection path, including a fast-math variant behind `NVTE_USE_FAST_MATH`, and rejects unsupported combinations such as stochastic rounding, grouped tensors, and RHT.

Checklist: