diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py
index a22fb7e16..9482dcccf 100644
--- a/tests/pytorch/distributed/run_numerics.py
+++ b/tests/pytorch/distributed/run_numerics.py
@@ -215,6 +215,9 @@ def _get_tolerances(dtype):
     if QUANTIZATION == "fp8_cs":
         return {"rtol": 0.4, "atol": 0.25}
     elif QUANTIZATION == "nvfp4":
+        if IS_HIP_EXTENSION:
+            # Higher tolerance on AMDGPU to account for the intermediate BF16 step in the GEMM
+            return {"rtol": 0.125, "atol": 0.15}
         # TODO(zhongboz): investigate why the tolerance is so large
         return {"rtol": 0.125, "atol": 0.12}
     elif QUANTIZATION is not None:
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index b22b50c70..e8f0f299e 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -103,6 +103,46 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
         return tensor._transpose.device.index
     return torch.cuda.current_device()
 
+
+if IS_HIP_EXTENSION:
+    def _should_use_bf16_output_for_nvfp4_tn(
+        A,
+        B,
+        layout: str,
+        out_dtype: Optional[torch.dtype],
+        out,
+        bias,
+        quantization_params,
+        debug_quantizer,
+        grad: bool,
+        accumulate: bool,
+        ub,
+        extra_output,
+        gelu: bool,
+    ) -> bool:
+        """Work around ROCm NVFP4 TN GEMM corruption when requesting FP32 output.
+
+        FIXME: hipBLASLt BF16xBF16->FP32 GEMM algos with ALPHA_DEVICE_VECTOR
+        produce incorrect results intermittently on AMDGPU. Return True for the
+        narrow path where we force BF16 output, which empirically covers the
+        corruption cases.
+        """
+        return (
+            layout == "TN"
+            and out_dtype == torch.float32
+            and out is None
+            and bias is not None
+            and quantization_params is None
+            and debug_quantizer is None
+            and not grad
+            and not accumulate
+            and ub is None
+            and extra_output is None
+            and not gelu
+            and (isinstance(A, NVFP4TensorStorage) or isinstance(B, NVFP4TensorStorage))
+        )
+
+
 def _select_kernel_fp4(layout: str, grad: bool, M: int, N: int, K: int):
     """Select kernel via tuned CSV lookup, falling back to AITER heuristic."""
     from aiter.ops.gemm_op_a4w4 import get_GEMM_config
@@ -371,6 +411,24 @@ def general_gemm(
         # FP8 block-scaling requires split accumulator
         use_split_accumulator = True
 
+    if IS_HIP_EXTENSION:
+        use_bf16_tn_output_workaround = _should_use_bf16_output_for_nvfp4_tn(
+            A,
+            B,
+            layout,
+            out_dtype,
+            out,
+            bias,
+            quantization_params,
+            debug_quantizer,
+            grad,
+            accumulate,
+            ub,
+            extra_output,
+            gelu,
+        )
+        out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype
+
     args = (
         A,
         transa,  # transa
@@ -400,6 +458,9 @@ def general_gemm(
 
     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
 
+    if IS_HIP_EXTENSION and use_bf16_tn_output_workaround:
+        out = cast_if_needed(out, torch.float32)
+
     if debug_quantizer is not None:
         out = debug_quantizer.process_gemm_output(out)
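
The cast-back step in general_gemm relies on BF16-to-FP32 widening being exact, so the workaround's only precision cost is the BF16 rounding of the GEMM output itself, which the loosened tolerance in run_numerics.py (atol 0.12 -> 0.15 on ROCm) absorbs. A minimal standalone sketch in plain PyTorch (no Transformer Engine imports; the tensor shape is arbitrary) illustrating this property:

    import torch

    # The workaround asks hipBLASLt for a BF16 result, then widens it to
    # FP32 on the Python side (cast_if_needed(out, torch.float32) above).
    out_bf16 = torch.randn(128, 256, dtype=torch.bfloat16)
    out_fp32 = out_bf16.to(torch.float32)

    # Every BF16 value is exactly representable in FP32, so the upcast is
    # lossless: round-tripping back to BF16 recovers the original tensor.
    assert torch.equal(out_fp32.to(torch.bfloat16), out_bf16)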