From 6e3eea50e5be7f4a0c1b81899a84b6356cd98e22 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 2 Apr 2026 11:27:27 -0500 Subject: [PATCH] Enable NVFP4 recipe --- transformer_engine/common/CMakeLists.txt | 6 +++--- transformer_engine/pytorch/quantization.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 27c26f7a8..9a23efc25 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -214,7 +214,8 @@ list(APPEND transformer_engine_cuda_sources fused_router/fused_topk_with_score_function.cu recipe/current_scaling.cu recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu) + recipe/fp8_block_scaling.cu + recipe/nvfp4.cu) list(APPEND transformer_engine_cuda_arch_specific_sources cast/cast.cu @@ -238,8 +239,7 @@ if(USE_CUDA) fused_attn/fused_attn_fp8.cu fused_attn/utils.cu swizzle/swizzle.cu - swizzle/swizzle_block_scaling.cu - recipe/nvfp4.cu) + swizzle/swizzle_block_scaling.cu) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu transpose/quantize_transpose_square_blockwise.cu) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 37766f5ce..8eded3622 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,7 +87,7 @@ def check_mxfp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_nvfp4_support() -> Tuple[bool, str]: if IS_HIP_EXTENSION: - return False, "ROCm TE currently not supporting NVFP4" + return True, "" """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, ""