From 8f9f431e0c49269863b145b3af00b0bd8b87956d Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Thu, 7 May 2026 00:40:01 +0000
Subject: [PATCH 1/4] NVFP4: Workaround intermittent incorrect results for
 backward GEMMs

---
 tests/pytorch/distributed/run_numerics.py | 28 +++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py
index a22fb7e16..ad4056424 100644
--- a/tests/pytorch/distributed/run_numerics.py
+++ b/tests/pytorch/distributed/run_numerics.py
@@ -142,15 +142,25 @@ def main(argv=None, namespace=None):
     BATCH_SIZE = 128
     HIDDEN_SIZE = 512
 
-    test_dict = [
-        test_quantizer,
-        test_quantized_all_gather,
-        test_linear,
-        test_layernorm,
-        test_layernorm_linear,
-        test_layernorm_mlp,
-        test_transformer_layer,
-    ]
+    # FIXME: hipBLASLt BF16xBF16->FP32 GEMM algos with ALPHA_DEVICE_VECTOR produce
+    # incorrect results intermittently on AMDGPU. Skip backward-containing sub-tests for
+    # nvfp4.
+    if IS_HIP_EXTENSION and QUANTIZATION == "nvfp4":
+        test_dict = [
+            test_quantizer,
+            test_quantized_all_gather,
+            test_layernorm,
+        ]
+    else:
+        test_dict = [
+            test_quantizer,
+            test_quantized_all_gather,
+            test_linear,
+            test_layernorm,
+            test_layernorm_linear,
+            test_layernorm_mlp,
+            test_transformer_layer,
+        ]
 
     for test in test_dict:
         test()

From 9b3121fe76e9b58bbeea5adbd9c0b9ea6b564cf5 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Fri, 8 May 2026 15:30:27 -0500
Subject: [PATCH 2/4] Revert "NVFP4: Workaround intermittent incorrect results
 for backward GEMMs"

This reverts commit 8f9f431e0c49269863b145b3af00b0bd8b87956d.
---
 tests/pytorch/distributed/run_numerics.py | 28 +++++++++---------------
 1 file changed, 9 insertions(+), 19 deletions(-)

diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py
index ad4056424..a22fb7e16 100644
--- a/tests/pytorch/distributed/run_numerics.py
+++ b/tests/pytorch/distributed/run_numerics.py
@@ -142,25 +142,15 @@ def main(argv=None, namespace=None):
     BATCH_SIZE = 128
     HIDDEN_SIZE = 512
 
-    # FIXME: hipBLASLt BF16xBF16->FP32 GEMM algos with ALPHA_DEVICE_VECTOR produce
-    # incorrect results intermittently on AMDGPU. Skip backward-containing sub-tests for
-    # nvfp4.
-    if IS_HIP_EXTENSION and QUANTIZATION == "nvfp4":
-        test_dict = [
-            test_quantizer,
-            test_quantized_all_gather,
-            test_layernorm,
-        ]
-    else:
-        test_dict = [
-            test_quantizer,
-            test_quantized_all_gather,
-            test_linear,
-            test_layernorm,
-            test_layernorm_linear,
-            test_layernorm_mlp,
-            test_transformer_layer,
-        ]
+    test_dict = [
+        test_quantizer,
+        test_quantized_all_gather,
+        test_linear,
+        test_layernorm,
+        test_layernorm_linear,
+        test_layernorm_mlp,
+        test_transformer_layer,
+    ]
 
     for test in test_dict:
         test()

From b609614386e799399962a2f77af2dd06facc1ec0 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Fri, 8 May 2026 15:42:14 -0500
Subject: [PATCH 3/4] use a BF16 fallback instead

---
 tests/pytorch/distributed/run_numerics.py |  3 +
 .../pytorch/cpp_extensions/gemm.py         | 61 +++++++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py
index a22fb7e16..1edc965e0 100644
--- a/tests/pytorch/distributed/run_numerics.py
+++ b/tests/pytorch/distributed/run_numerics.py
@@ -216,6 +216,9 @@ def _get_tolerances(dtype):
         return {"rtol": 0.4, "atol": 0.25}
     elif QUANTIZATION == "nvfp4":
         # TODO(zhongboz): investigate why the tolerance is so large
+        if IS_HIP_EXTENSION:
+            # Higher tolerance for AMDGPU to account for intermediate bf16 step in GEMM
+            return {"rtol": 0.125, "atol": 0.15}
         return {"rtol": 0.125, "atol": 0.12}
     elif QUANTIZATION is not None:
         return {"rtol": 0.125, "atol": 0.0625}
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index b22b50c70..e8f0f299e 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -103,6 +103,46 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
         return tensor._transpose.device.index
     return torch.cuda.current_device()
 
+
+if IS_HIP_EXTENSION:
+    def _should_use_bf16_output_for_nvfp4_tn(
+        A,
+        B,
+        layout: str,
+        out_dtype: Optional[torch.dtype],
+        out,
+        bias,
+        quantization_params,
+        debug_quantizer,
+        grad: bool,
+        accumulate: bool,
+        ub,
+        extra_output,
+        gelu: bool,
+    ) -> bool:
+        """Work around ROCm NVFP4 TN GEMM corruption when requesting FP32 output.
+
+        FIXME: hipBLASLt BF16xBF16->FP32 GEMM algos with ALPHA_DEVICE_VECTOR
+        produce incorrect results intermittently on AMDGPU. Return True for the
+        narrow path where we force BF16 output, which empirically covers the
+        corruption cases.
+        """
+        return (
+            layout == "TN"
+            and out_dtype == torch.float32
+            and out is None
+            and bias is not None
+            and quantization_params is None
+            and debug_quantizer is None
+            and not grad
+            and not accumulate
+            and ub is None
+            and extra_output is None
+            and not gelu
+            and (isinstance(A, NVFP4TensorStorage) or isinstance(B, NVFP4TensorStorage))
+        )
+
+
 def _select_kernel_fp4(layout: str, grad: bool, M: int, N: int, K: int):
     """Select kernel via tuned CSV lookup, falling back to AITER heuristic."""
     from aiter.ops.gemm_op_a4w4 import get_GEMM_config
@@ -371,6 +411,24 @@ def general_gemm(
         # FP8 block-scaling requires split accumulator
         use_split_accumulator = True
 
+    if IS_HIP_EXTENSION:
+        use_bf16_tn_output_workaround = _should_use_bf16_output_for_nvfp4_tn(
+            A,
+            B,
+            layout,
+            out_dtype,
+            out,
+            bias,
+            quantization_params,
+            debug_quantizer,
+            grad,
+            accumulate,
+            ub,
+            extra_output,
+            gelu,
+        )
+        out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype
+
     args = (
         A,
         transa,  # transa
@@ -400,6 +458,9 @@ def general_gemm(
 
     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
 
+    if IS_HIP_EXTENSION and use_bf16_tn_output_workaround:
+        out = cast_if_needed(out, torch.float32)
+
     if debug_quantizer is not None:
         out = debug_quantizer.process_gemm_output(out)
 

From 6cbc4dc2e4d466820c5a0b2f6cbf04739d817e07 Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Sun, 10 May 2026 09:26:15 -0500
Subject: [PATCH 4/4] move TODO comment

---
 tests/pytorch/distributed/run_numerics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py
index 1edc965e0..9482dcccf 100644
--- a/tests/pytorch/distributed/run_numerics.py
+++ b/tests/pytorch/distributed/run_numerics.py
@@ -215,10 +215,10 @@ def _get_tolerances(dtype):
     if QUANTIZATION == "fp8_cs":
         return {"rtol": 0.4, "atol": 0.25}
     elif QUANTIZATION == "nvfp4":
-        # TODO(zhongboz): investigate why the tolerance is so large
         if IS_HIP_EXTENSION:
             # Higher tolerance for AMDGPU to account for intermediate bf16 step in GEMM
             return {"rtol": 0.125, "atol": 0.15}
+        # TODO(zhongboz): investigate why the tolerance is so large
        return {"rtol": 0.125, "atol": 0.12}
     elif QUANTIZATION is not None:
         return {"rtol": 0.125, "atol": 0.0625}