Changes from all commits (57 commits)
19b6b08
Initial implementation
zianglih May 9, 2026
7b0b2d0
Make 4over6 compile time for dequant
zianglih May 9, 2026
1e5b6ad
Expand 1d fwd+bwd test
zianglih May 9, 2026
99660fc
Refactor
zianglih May 9, 2026
cb2e0a3
Clean up
zianglih May 9, 2026
2c066f9
Clean up
zianglih May 9, 2026
69e8f3a
Add gemm test
zianglih May 9, 2026
009e651
Add more tests and fix offload
zianglih May 9, 2026
3153fc3
Fix offload
zianglih May 9, 2026
e31b758
Clean up arg
zianglih May 9, 2026
fcd526c
Add more test
zianglih May 9, 2026
100c378
Add more tests
zianglih May 10, 2026
1c9f26b
Clean up test
zianglih May 10, 2026
93fe922
Refactor cuh kernel impl
zianglih May 10, 2026
f4e4a4e
Further extract
zianglih May 10, 2026
b3f59ee
Clean up
zianglih May 10, 2026
31decf9
Add recipe_id
zianglih May 10, 2026
2fa6b8c
Fix failing unit tests
zianglih May 10, 2026
7df2db0
Clean up test
zianglih May 10, 2026
ce85be2
Clean up
zianglih May 10, 2026
1b68038
Refactor ref
zianglih May 10, 2026
bb722a3
Update comments and docs
zianglih May 10, 2026
fe18a1e
Drop unnecessary test_sanity workaround
zianglih May 10, 2026
522e93e
Refactor `QuantizerRole`
zianglih May 11, 2026
782b7ee
Allow separate recipe 4over6 config
zianglih May 11, 2026
d9cd12c
Support 2d
zianglih May 12, 2026
708c1ec
Refactor 2d
zianglih May 12, 2026
4d31f18
Clean up anti pattern
zianglih May 12, 2026
dfc15f2
Enforce 4over6 consistency
zianglih May 12, 2026
9453670
Update comments
zianglih May 12, 2026
6d871da
Update docs
zianglih May 12, 2026
f8338e8
Fix test
zianglih May 12, 2026
c9bc921
Drop test_fusible_ops
zianglih May 12, 2026
00ba694
Revert "Drop test_fusible_ops"
zianglih May 12, 2026
3252d4e
Refactor test_fusible_ops
zianglih May 12, 2026
3f33c1d
Refactor ref and extend cpp test
zianglih May 12, 2026
8607e03
Clean up cpp test
zianglih May 12, 2026
d3dbf34
Minor comment
zianglih May 12, 2026
565f33f
Drop doc
zianglih May 12, 2026
54b4da8
Explicit handle conditional smem buffer
zianglih May 12, 2026
fa09200
Further clean up
zianglih May 12, 2026
e57e8be
More templates
zianglih May 12, 2026
a1df319
Simplify cpp
zianglih May 12, 2026
21720da
Drop write back lifting
zianglih May 12, 2026
b1d073a
Add MAE and dedicated fast math env var
zianglih May 12, 2026
0392708
Harden cpp test
zianglih May 12, 2026
0b77a37
Add warning and err fast math coverage
zianglih May 12, 2026
81e579e
Fold test case and clean up cpp test
zianglih May 12, 2026
1e311ef
Initial 448 vs 256 implementation
zianglih May 12, 2026
38a1c4c
Use e4m3 max instead of boolean, more template
zianglih May 12, 2026
3cdd9d9
Add benchmark script and minor optimization
zianglih May 13, 2026
7deba75
Use standalone kernels
zianglih May 13, 2026
93dbf2b
Use cp async
zianglih May 13, 2026
8819d12
Add benchmark script
zianglih May 13, 2026
24e417b
Minor fix after rebase
zianglih May 13, 2026
472e5b8
Naming consistency
zianglih May 13, 2026
83e2308
Remove 4over6 benchmark
zianglih May 13, 2026
24 changes: 24 additions & 0 deletions docs/envvars.rst
@@ -287,6 +287,30 @@ Kernel Configuration
:Default: ``0``
:Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar.

.. envvar:: NVTE_NVFP4_4OVER6

:Type: ``str`` (``weights``, ``activations``, or ``all``)
:Default: unset
:Description: Enable per-block map-to-4 versus map-to-6 candidate selection for selected NVFP4 quantizers in the ``NVFP4BlockScaling`` recipe. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. The selected block scale is the candidate with lower configured input-domain error, and ties select map-to-6. By default, this mode keeps the standard NVFP4 global E4M3 scale bound of 448. Tensors using 4over6 currently require RHT and stochastic rounding to be disabled; activation and backward scopes therefore require ``NVTE_NVFP4_DISABLE_RHT=1`` and ``NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1``.
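
   As a rough illustration of the selection rule described above, the following CPU-side sketch compares the two candidates for one block. It is not the Transformer Engine kernel code: the E2M1 rounding helper and the scale formula are simplified assumptions, and quantization of the block scale to E4M3 is omitted.

   #include <algorithm>
   #include <array>
   #include <cmath>
   #include <cstddef>

   // FP4 (E2M1) representable magnitudes.
   constexpr std::array<float, 8> kE2M1Values = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

   inline float round_to_e2m1(float x) {
     float best = kE2M1Values[0];
     for (float v : kE2M1Values) {
       if (std::fabs(std::fabs(x) - v) < std::fabs(std::fabs(x) - best)) best = v;
     }
     return std::copysign(best, x);
   }

   // Quantize-dequantize one block with its amax mapped to `target_max`
   // (6 = standard map-to-6 candidate, 4 = map-to-4 candidate) and return the
   // mean absolute error in the input domain (MAE; MSE would square the diff).
   inline float candidate_error(const float *block, std::size_t n, float target_max) {
     float amax = 0.0f;
     for (std::size_t i = 0; i < n; ++i) amax = std::max(amax, std::fabs(block[i]));
     if (amax == 0.0f) return 0.0f;
     const float scale = target_max / amax;  // simplified: no E4M3 block-scale rounding
     float err = 0.0f;
     for (std::size_t i = 0; i < n; ++i) {
       const float deq = round_to_e2m1(block[i] * scale) / scale;
       err += std::fabs(block[i] - deq);
     }
     return err / static_cast<float>(n);
   }

   // 4over6 rule: use the map-to-4 candidate only when its error is strictly
   // lower; a tie keeps the standard map-to-6 scale.
   inline bool use_map_to_4(const float *block, std::size_t n) {
     return candidate_error(block, n, 4.0f) < candidate_error(block, n, 6.0f);
   }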

.. envvar:: NVTE_NVFP4_4OVER6_E4M3_USE_256

:Type: ``str`` (``weights``, ``activations``, or ``all``)
:Default: unset
:Description: Select NVFP4 4over6 quantizers that use 256 instead of 448 as the global E4M3 scale bound. ``weights`` selects weight tensor roles, ``activations`` selects non-weight tensor roles, and ``all`` selects both. This option is only meaningful for tensor roles that also enable :envvar:`NVTE_NVFP4_4OVER6`.

.. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE

:Type: ``str`` (``MAE`` or ``MSE``)
:Default: ``MAE``
:Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe.
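
   For concreteness, a minimal sketch of how the two metrics differ when accumulating the per-element deviation before the candidates are compared; the helper below is illustrative, not the kernel implementation.

   #include <cmath>
   #include <cstddef>

   // Illustrative per-block error for one candidate; `use_mae` mirrors
   // NVTE_NVFP4_4OVER6_ERR_MODE (MAE is the default, MSE squares the deviation).
   inline float block_error(const float *input, const float *dequantized,
                            std::size_t n, bool use_mae) {
     float err = 0.0f;
     for (std::size_t i = 0; i < n; ++i) {
       const float diff = input[i] - dequantized[i];
       err += use_mae ? std::fabs(diff) : diff * diff;
     }
     return err;
   }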

.. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH

:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Allow the NVFP4 4over6 candidate error computation to use faster non-strict floating-point expressions. By default, 4over6 error comparison uses strict expressions; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path.

Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

628 changes: 529 additions & 99 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu

Large diffs are not rendered by default.

73 changes: 59 additions & 14 deletions tests/cpp/operator/test_dequantize_nvfp4.cu
Collaborator:

This test is okay, but it would provide much more confidence if the NVFP4 quantization tests compared against a CPU reference impl.

Contributor Author:

Extended tests/cpp/operator/test_cast_nvfp4_transpose.cu coverage in 3bb42b1.

@@ -46,8 +46,9 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
OType *output,
size_t rows,
size_t cols,
size_t scale_stride) {
constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
size_t scale_stride,
int e4m3_max) {
const float factor_inv = 1.0f / (6.0f * static_cast<float>(e4m3_max));
constexpr size_t BLOCK_SIZE = 16;
const size_t Mread = cols / BLOCK_SIZE;
const size_t bytes_per_block = BLOCK_SIZE / 2;
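
The effect of this change is that the reference dequantization factor is no longer hard-coded to 1/(6*448) but derived from the configured E4M3 bound. A sketch of the per-element reconstruction this implies is below; the helper name and argument composition are assumptions, not the test's exact code.

// Illustrative composition implied by factor_inv (assumed, not the test's exact code).
inline float dequant_nvfp4_ref(float fp4_value, float e4m3_block_scale,
                               float amax, int e4m3_max) {
  const float factor_inv = 1.0f / (6.0f * static_cast<float>(e4m3_max));
  // Default bound: 1/(6*448); with NVTE_NVFP4_4OVER6_E4M3_USE_256: 1/(6*256).
  return fp4_value * e4m3_block_scale * amax * factor_inv;
}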
@@ -90,7 +91,9 @@ float compute_amax(test::Tensor &t, size_t rows, size_t cols) {
// against a CPU reference computed from the quantized data.
template <typename OutputType>
void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
const bool row_scaled_nvfp4) {
const bool row_scaled_nvfp4,
const bool use_4over6,
const int e4m3_max) {
using namespace test;
DType otype = TypeInfo<OutputType>::dtype;

@@ -105,6 +108,10 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,

// Configure quantized tensor amax
size_t amax_size = 1;
quantized.set_nvfp4_4over6(use_4over6);
quantized.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
ASSERT_EQ(quantized.nvfp4_4over6(), use_4over6);
ASSERT_EQ(quantized.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448));
if (row_scaled_nvfp4) {
quantized.set_row_scaled_nvfp4(true);
amax_size = rows;
@@ -116,7 +123,10 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,

// Quantize
if (rows > 0 && cols > 0) {
nvte_quantize(input.data(), quantized.data(), 0);
QuantizationConfigWrapper quant_config;
quant_config.set_nvfp4_4over6(use_4over6);
quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
@@ -146,7 +156,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
std::make_unique<OutputType[]>(rows * cols);
compute_ref_dequantize_nvfp4<OutputType>(
fp4_data, scales, amax_vals, ref_output.get(),
rows, cols, scale_stride);
rows, cols, scale_stride, (use_4over6 ? e4m3_max : 448));

// Compare results from TE and reference impls
auto [atol, rtol] = getTolerances(otype);
@@ -156,7 +166,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
// Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path.
template <typename OutputType>
void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
const bool row_scaled_nvfp4) {
const bool row_scaled_nvfp4,
const bool use_4over6,
const int e4m3_max) {
using namespace test;
DType otype = TypeInfo<OutputType>::dtype;

@@ -165,6 +177,10 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,

Tensor quantized_compact("quantized_compact", std::vector<size_t>{rows, cols},
DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized_compact.set_nvfp4_4over6(use_4over6);
quantized_compact.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
ASSERT_EQ(quantized_compact.nvfp4_4over6(), use_4over6);
ASSERT_EQ(quantized_compact.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448));
if (row_scaled_nvfp4) {
quantized_compact.set_row_scaled_nvfp4(true);
} else if (rows > 0 && cols > 0) {
@@ -174,7 +190,10 @@ }
}

if (rows > 0 && cols > 0) {
nvte_quantize(input.data(), quantized_compact.data(), 0);
QuantizationConfigWrapper quant_config;
quant_config.set_nvfp4_4over6(use_4over6);
quant_config.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0);
cudaDeviceSynchronize();
}

@@ -186,6 +205,10 @@
// Create tensor with same FP4 data but swizzled scales
Tensor quantized_swizzled("quantized_swizzled", std::vector<size_t>{rows, cols},
DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized_swizzled.set_nvfp4_4over6(use_4over6);
quantized_swizzled.set_nvfp4_e4m3_max((use_4over6 ? e4m3_max : 448));
ASSERT_EQ(quantized_swizzled.nvfp4_4over6(), use_4over6);
ASSERT_EQ(quantized_swizzled.nvfp4_e4m3_max(), (use_4over6 ? e4m3_max : 448));
if (row_scaled_nvfp4) {
quantized_swizzled.set_row_scaled_nvfp4(true);
} else {
@@ -260,7 +283,9 @@ std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
class DequantizeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
transformer_engine::DType,
bool>> {};
bool,
bool,
int>> {};

TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
{
@@ -271,10 +296,12 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
const auto tensor_size = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const bool row_scaled_nvfp4 = std::get<2>(GetParam());
const bool use_4over6 = std::get<3>(GetParam());
const int e4m3_max = use_4over6 ? std::get<4>(GetParam()) : 448;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
performTest_dequantize_nvfp4<OutputType>(
tensor_size.first, tensor_size.second, row_scaled_nvfp4);
tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, e4m3_max);
);
}

@@ -284,21 +311,30 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(nvfp4_tensor_dims),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Bool()),
::testing::Bool(),
::testing::Bool(),
::testing::Values(448, 256)),
[](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info)
{
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
(std::get<2>(info.param) ? "RowScaled" : "PerTensor");
(std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
(std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" +
(std::get<3>(info.param)
? (std::get<4>(info.param) == 256 ? "E4M3Max256" : "E4M3Max448")
: (std::get<4>(info.param) == 256 ? "E4M3Max256Ignored"
: "E4M3Max448"));
return name;
}
);

class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
transformer_engine::DType,
bool>> {};
bool,
bool,
int>> {};

TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
{
@@ -309,10 +345,12 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
const auto tensor_size = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const bool row_scaled_nvfp4 = std::get<2>(GetParam());
const bool use_4over6 = std::get<3>(GetParam());
const int e4m3_max = use_4over6 ? std::get<4>(GetParam()) : 448;

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
performTest_dequantize_nvfp4_swizzled<OutputType>(
tensor_size.first, tensor_size.second, row_scaled_nvfp4);
tensor_size.first, tensor_size.second, row_scaled_nvfp4, use_4over6, e4m3_max);
);
}

@@ -322,13 +360,20 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(nvfp4_tensor_dims),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Bool()),
::testing::Bool(),
::testing::Bool(),
::testing::Values(448, 256)),
[](const testing::TestParamInfo<DequantizeNVFP4SwizzledTestSuite::ParamType>& info)
{
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
(std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
(std::get<3>(info.param) ? "FourOverSix" : "Default") + "X" +
(std::get<3>(info.param)
? (std::get<4>(info.param) == 256 ? "E4M3Max256" : "E4M3Max448")
: (std::get<4>(info.param) == 256 ? "E4M3Max256Ignored"
: "E4M3Max448")) + "X" +
"Swizzled";
return name;
}
24 changes: 24 additions & 0 deletions tests/cpp/test_common.cu
@@ -440,6 +440,30 @@ void Tensor::set_row_scaled_nvfp4(bool row_scaled_nvfp4) {
}
}

void Tensor::set_nvfp4_4over6(bool nvfp4_4over6) {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
"NVFP4 4over6 is only supported for NVFP4 tensors.");
tensor_.set_nvfp4_4over6(nvfp4_4over6);
}

void Tensor::set_nvfp4_e4m3_max(int nvfp4_e4m3_max) {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
"NVFP4 E4M3 max is only supported for NVFP4 tensors.");
tensor_.set_nvfp4_e4m3_max(nvfp4_e4m3_max);
}

bool Tensor::nvfp4_4over6() const {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
"NVFP4 4over6 is only supported for NVFP4 tensors.");
return tensor_.get_nvfp4_4over6();
}

int Tensor::nvfp4_e4m3_max() const {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
"NVFP4 E4M3 max is only supported for NVFP4 tensors.");
return tensor_.get_nvfp4_e4m3_max();
}

void Tensor::to_cpu() {
if (data_rowwise_) { data_rowwise_->to_cpu(); }
if (data_columnwise_) { data_columnwise_->to_cpu(); }
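
Taken together with the diffs above, a minimal usage fragment for these new accessors might look like the following; it is composed from the test changes in this PR rather than copied verbatim, is not compilable standalone, and assumes an already-filled source tensor named input.

// Usage fragment composed from the test diffs above (illustrative only).
const size_t rows = 128, cols = 128;
Tensor quantized("quantized", std::vector<size_t>{rows, cols},
                 DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized.set_nvfp4_4over6(true);
quantized.set_nvfp4_e4m3_max(256);  // pairs with NVTE_NVFP4_4OVER6_E4M3_USE_256

QuantizationConfigWrapper quant_config;
quant_config.set_nvfp4_4over6(true);
quant_config.set_nvfp4_e4m3_max(256);
nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0);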
5 changes: 5 additions & 0 deletions tests/cpp/test_common.h
@@ -293,10 +293,15 @@ class Tensor {
return columnwise_;
}

bool nvfp4_4over6() const;
int nvfp4_e4m3_max() const;

void set_tensor_amax_nullptr();

void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales);
void set_row_scaled_nvfp4(bool row_scaled_nvfp4);
void set_nvfp4_4over6(bool nvfp4_4over6);
void set_nvfp4_e4m3_max(int nvfp4_e4m3_max);

void to_cpu();
void from_cpu();