diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index bcacb2f801..4c1ffbaaa4 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -178,9 +178,13 @@ void run_grouped_gemm_case(const TestParams& params) { GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else - if (getDeviceComputeCapability() < blackwellComputeCapability) { + const int compute_capability = getDeviceComputeCapability(); + if (compute_capability < blackwellComputeCapability) { GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; } + if (compute_capability == 120) { + GTEST_SKIP() << "Grouped GEMM is currently unsupported on SM120."; + } const std::vector> shapes = make_shapes(params.shape_case); @@ -356,9 +360,13 @@ void run_grouped_gemm_discrete_out_case(const TestParams& params) { GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else - if (getDeviceComputeCapability() < blackwellComputeCapability) { + const int compute_capability = getDeviceComputeCapability(); + if (compute_capability < blackwellComputeCapability) { GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; } + if (compute_capability == 120) { + GTEST_SKIP() << "Grouped GEMM is currently unsupported on SM120."; + } const std::vector> shapes = make_shapes(params.shape_case); @@ -527,9 +535,13 @@ void run_grouped_gemm_discrete_in_case(const TestParams& params) { GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " << CUBLAS_VERSION << "."; #else - if (getDeviceComputeCapability() < blackwellComputeCapability) { + const int compute_capability = getDeviceComputeCapability(); + if (compute_capability < blackwellComputeCapability) { GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; } + if (compute_capability == 120) { + GTEST_SKIP() << "Grouped GEMM is currently unsupported on SM120."; + } const std::vector> shapes = make_shapes(params.shape_case); diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index 285ec7ba0c..78f45286ea 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -48,6 +48,19 @@ fp8_available = is_fp8_available() +def _cmp_dist(ground_truth, output, parallel_mode): + if parallel_mode == "column" and torch.cuda.get_device_capability() == (12, 0): + # SM120: distributed column-parallel path may show a single-element + # activation outlier slightly above default fp32 atol, while grads match. 
+ torch.testing.assert_close( + ground_truth["activation"], output["activation"], atol=1.2e-5, rtol=1.3e-6 + ) + torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"]) + torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"]) + else: + _cmp(ground_truth, output) + + def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): if tp_size is None: tp_size = WORLD_SIZE @@ -445,7 +458,7 @@ def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwa x.grad.zero_() ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) @run_debug_test @@ -466,7 +479,7 @@ def test_disable_fp8_layer(parallel_mode, **kwargs): y = _run_forward_backward(x, model, parallel_mode) output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) @run_debug_test @@ -554,7 +567,7 @@ def test_per_tensor_scaling( x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs ) - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) @run_debug_test @@ -617,7 +630,7 @@ def test_fake_quant_fp8( _get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None ) ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) def _init_distributed(): diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 53c7a5e7cc..b1883a3bc9 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -30,6 +30,11 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +FP8_DEFAULT_RTOL_ATOL = (0.125, 0.0625) +FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL = (0.4, 0.25) +BF16_DEFAULT_RTOL_ATOL = (0.025, 0.00125) +BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL = (0.05, 0.01) + class multi_module_model(torch.nn.Module): def __init__(self, module, num_layers, *args, **kwargs): @@ -551,9 +556,33 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): + is_sm120 = torch.cuda.get_device_capability() == (12, 0) + is_deterministic_mode = os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0" for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 - atol = 0.0625 if opts.fp8 else 0.00125 + if opts.fp8: + if ( + opts.quantization == "fp8_current_scaling" + and is_sm120 + and is_deterministic_mode + ): + # SM120 deterministic mode disables fused attn, so the runtime uses alternate attn backends. + # Combined with FP8 CS, this path needs the looser distributed fp8_cs tolerance policy. + rtol, atol = FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL + else: + rtol, atol = FP8_DEFAULT_RTOL_ATOL + else: + rtol, atol = BF16_DEFAULT_RTOL_ATOL + if ( + is_sm120 + and is_deterministic_mode + and opts.layer_type == te.TransformerLayer + and opts.num_layers > 1 + and opts.overlap_rs_dgrad + ): + # SM120 + deterministic training disables fused attn. + # The runtime then selects an alternate attn backend, and + # the overlap path can show tiny BF16 accumulation-order drift vs reference.
+ rtol, atol = BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py index d46a874695..1f95b8f31b 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -27,6 +27,35 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +SM120_SWIZZLED_SCALE_RTOL_ATOL = (1e-3, 1e-3) +STRICT_SCALE_RTOL_ATOL = (0.0, 0.0) + + +def _scale_compare_tolerances(expected_swizzled_layout: bool) -> tuple[float, float]: + """Return comparison tolerances for NVFP4 scale tensors. + + On SM120 with swizzled scale layout enabled, grouped NVFP4 can route through a + fallback path whose scale accumulation order differs slightly from the + Python reference. Layout must still match, but exact bitwise equality of + scale values is not guaranteed. + """ + if torch.cuda.get_device_capability() == (12, 0) and expected_swizzled_layout: + return SM120_SWIZZLED_SCALE_RTOL_ATOL + return STRICT_SCALE_RTOL_ATOL + + +def _reference_scale_for_layout( + ref_unswizzled: torch.Tensor, + split_m: int, + n: int, + columnwise: bool, + with_gemm_swizzled_scales: bool, +) -> torch.Tensor: + """Return reference scale in expected backend-reported layout.""" + if with_gemm_swizzled_scales: + return swizzle_nvfp4_scale(split_m, n, ref_unswizzled.clone(), columnwise=columnwise) + return ref_unswizzled + def fused_grouped_quantize( x: torch.Tensor, split_section_tensor: torch.Tensor, quantizer: NVFP4Quantizer @@ -56,7 +85,6 @@ def check_grouped_tensor_nvfp4_versus_reference( ) -> None: te_dtype = tex.DType.kFloat4E2M1 - split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") # Setup device and random seed @@ -98,6 +126,15 @@ def check_grouped_tensor_nvfp4_versus_reference( group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) # get a list of nvfp4 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + expected_swizzled_layout = bool(group_quantized_output._with_gemm_swizzled_scales) + for i, output in enumerate(split_quantize_outputs): + split_flag = bool(output._with_gemm_swizzled_scales) + assert split_flag == expected_swizzled_layout, ( + "Grouped output and split output disagree on swizzled-scale metadata " + f"(split {i}: grouped={expected_swizzled_layout}, split={split_flag})" + ) + # Fetch appropriate scale comparison tolerances based on expected swizzled layout and CC + scale_atol, scale_rtol = _scale_compare_tolerances(expected_swizzled_layout) if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] @@ -121,11 +158,15 @@ def check_grouped_tensor_nvfp4_versus_reference( ), "The scale shape is not correctly aligned" x_sx_i = x_sx[i].clone() x_sx_ref_i = x_sx_ref[i].clone() - if optimize_for_gemm: - x_sx_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_ref_i, columnwise=False - ) - torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + # Swizzle the reference scale based on expected_swizzled_layout + x_sx_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_ref_i, + split_m=split_sections[i], + n=N, + columnwise=False, + 
with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=scale_atol, rtol=scale_rtol) if return_transpose: x_qx_t = [ @@ -151,11 +192,14 @@ def check_grouped_tensor_nvfp4_versus_reference( ), "The scale shape is not correctly aligned" x_sx_t_i = x_sx_t[i].clone() x_sx_t_ref_i = x_sx_t_ref[i].clone() - if optimize_for_gemm: - x_sx_t_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_t_ref_i, columnwise=True - ) - torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + x_sx_t_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_t_ref_i, + split_m=split_sections[i], + n=N, + columnwise=True, + with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=scale_atol, rtol=scale_rtol) def check_grouped_tensor_nvfp4_with_paged_stashing( @@ -173,7 +217,6 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( ) -> None: te_dtype = tex.DType.kFloat4E2M1 - assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" assert valid_M < M, "valid_M must be less than M when with_paged_stashing is True" @@ -225,6 +268,15 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( # get a list of nvfp4 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + expected_swizzled_layout = bool(group_quantized_output._with_gemm_swizzled_scales) + for i, output in enumerate(split_quantize_outputs): + split_flag = bool(output._with_gemm_swizzled_scales) + assert split_flag == expected_swizzled_layout, ( + "Grouped output and split output disagree on swizzled-scale metadata " + f"(split {i}: grouped={expected_swizzled_layout}, split={split_flag})" + ) + # Fetch appropriate scale comparison tolerances based on expected swizzled layout and CC + scale_atol, scale_rtol = _scale_compare_tolerances(expected_swizzled_layout) if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] @@ -248,11 +300,15 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( ), "The scale shape is not correctly aligned" x_sx_i = x_sx[i].clone() x_sx_ref_i = x_sx_ref[i].clone() - if optimize_for_gemm: - x_sx_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_ref_i, columnwise=False - ) - torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + # Swizzle the reference scale based on expected swizzled layout + x_sx_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_ref_i, + split_m=split_sections[i], + n=N, + columnwise=False, + with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=scale_atol, rtol=scale_rtol) if return_transpose: x_qx_t = [ @@ -275,11 +331,14 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) x_sx_t_i = x_sx_t[i].clone() x_sx_t_ref_i = x_sx_t_ref[i].clone() - if optimize_for_gemm: - x_sx_t_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_t_ref_i, columnwise=True - ) - torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + x_sx_t_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_t_ref_i, + split_m=split_sections[i], + n=N, + columnwise=True, + with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=scale_atol, rtol=scale_rtol) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @@ -402,6 +461,11 
@@ def test_grouped_tensor_nvfp4_with_paged_stashing( with_rht: bool, optimize_for_gemm: bool, ) -> None: + if torch.cuda.get_device_capability() == (12, 0): + pytest.skip( + "SM120: paged-stashing grouped NVFP4 path is currently unsupported. " + "(group_hadamard_transform_amax assumes sum(split_sections) == input rows)." + ) # paged stashing means that the sum of total tokens is less than # or equal to the buffer size, you can have buffer [2048, 1024] diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index b14eeb815b..7a6fd5b43a 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -14,11 +14,31 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +SM120_SR_EQUIVALENCE_ATOL = 2e-7 + seed = 12345 torch.manual_seed(seed) torch.cuda.manual_seed(seed) + +def _assert_sr_vs_rn_behavior( + me_sr: torch.Tensor, + me_rn: torch.Tensor, + me_t_sr: torch.Tensor, + me_t_rn: torch.Tensor, +) -> None: + if torch.cuda.get_device_capability() == (12, 0): + # SM120 currently disables NVFP4 stochastic rounding in backend paths, + # so SR and RN should be numerically equivalent. + torch.testing.assert_close(me_sr, me_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) + torch.testing.assert_close(me_t_sr, me_t_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) + else: + assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." + assert ( + me_t_sr < me_t_rn + ), "Stochastic rounding failed - error larger than the round to nearest." + + def unpack_fp4(x: torch.Tensor) -> torch.Tensor: repeated = x.repeat_interleave(2, dim=1) repeated[:, 0::2] &= 0x0F @@ -247,7 +267,7 @@ def check_quantization_nvfp4_versus_reference( me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean()) sr_result = torch.zeros_like(x).float() sr_t_result = torch.zeros_like(x).float().t().contiguous() - for i in range(n_iters): + for _ in range(n_iters): q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4( x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT ) @@ -278,8 +298,7 @@ def check_quantization_nvfp4_versus_reference( print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") - assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." - assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest." + _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) def check_group_quantization_nvfp4_versus_reference( @@ -362,10 +381,7 @@ print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") - assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." - assert ( - me_t_sr < me_t_rn - ), "Stochastic rounding failed - error larger than the round to nearest."
+ _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 62a6291797..4cf5c6ec1b 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -128,7 +128,7 @@ def test_custom_recipe_grouped_linear_sanity(): in_features = 64 out_features = 64 # Each per-GEMM M dim must be a multiple of 16 to satisfy cuBLAS FP8 GEMM's - # leading-dimension alignment requirement on Hopper (sm_90). + # leading-dimension alignment requirement on Hopper and SM120 paths. m_splits = [16] * num_gemms batch = sum(m_splits) @@ -281,7 +281,7 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): in_features = 64 out_features = 64 # batch must be a multiple of 16 to satisfy cuBLAS FP8 GEMM's leading-dim - # alignment requirement on Hopper (sm_90). + # alignment requirement on Hopper and SM120 paths. batch = 16 op = Linear(in_features, out_features, params_dtype=torch.bfloat16) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a718ea2a8a..4df2e73dec 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -64,6 +64,7 @@ nvfp4_available = is_nvfp4_available() sm_80plus = get_device_compute_capability() >= (8, 0) +sm_120 = get_device_compute_capability() == (12, 0) seed = 1234 # Reset RNG states. @@ -2703,9 +2704,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): max_seqlen_kv=config.max_seqlen_kv, ) + tols = dtype_tols(dtype) + if sm_120: + # sm120 FusedAttention does not support T3HD/TH3D layouts, so for T3HD/TH3D, the test falls back to using Flash Attn backend + # whereas for BSHD/SBHD, the test uses FusedAttention backend by default. Hence, relaxing the atol tolerance for T3HD/TH3D. 
+ tols["atol"] = max(tols["atol"], 4e-3) torch.testing.assert_close( y_bshd, y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(), + **tols, ) @@ -2887,6 +2894,8 @@ def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_b pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") if torch.cuda.get_device_capability() < (10, 0): pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if torch.cuda.get_device_capability() == (12, 0): + pytest.skip("Grouped GEMM is currently unsupported on SM120.") if not is_bf16_available(): pytest.skip("bfloat16 is required for grouped GEMM test.") @@ -3039,6 +3048,8 @@ def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) - """ if torch.cuda.get_device_capability() < (10, 0): pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if torch.cuda.get_device_capability() == (12, 0): + pytest.skip("Grouped GEMM is currently unsupported on SM120.") if not is_bf16_available(): pytest.skip("bfloat16 is required for grouped GEMM test.") if quant_type == "mxfp8" and not mxfp8_available: @@ -3209,6 +3220,8 @@ def test_grouped_gemm_grouped_tensor_mxfp8( pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") if torch.cuda.get_device_capability() < (10, 0): pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if torch.cuda.get_device_capability() == (12, 0): + pytest.skip("Grouped GEMM is currently unsupported on SM120.") if dtype == torch.bfloat16 and not is_bf16_available(): pytest.skip("bfloat16 is required for grouped GEMM test.") diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 63c1b046ff..0568fc521b 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -39,7 +39,7 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t break; } case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { + if (is_supported_by_CC_100_or_newer()) { mxfp8::dequantize(input, output, stream); } else { NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); @@ -62,7 +62,7 @@ inline void group_dequantize_helper(const GroupedTensor &input, GroupedTensor *o switch (input.scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { + if (is_supported_by_CC_100_or_newer()) { mxfp8::group_dequantize(&input, output, stream); } else { NVTE_ERROR("MXFP8 Grouped Dequantization is NOT supported by architectures < 10.0"); diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 06e8f0e306..11b28c2483 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -46,7 +46,10 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // SM120 has lower shared-memory headroom than SM100 for this kernel family. + // Keep TMA kernels disabled on SM120 and use the non-TMA fallback path. 
+ const bool use_tma_kernels = + (cols % 32 == 0) && is_supported_by_CC_100_or_newer() && !is_supported_by_CC_120(); if (use_tma_kernels) { Tensor dummy_grad_tensor; fp8::cast_gated_tma(input, dummy_grad_tensor, @@ -83,7 +86,7 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise output tensor should be FP8."); } - NVTE_CHECK(is_supported_by_CC_100(), + NVTE_CHECK(is_supported_by_CC_100_or_newer(), "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); Tensor dummy_grad_tensor; mxfp8::quantize_gated(input, dummy_grad_tensor, @@ -137,7 +140,10 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // SM120 has lower shared-memory headroom than SM100 for this kernel family. + // Keep TMA kernels disabled on SM120 and use the non-TMA fallback path. + const bool use_tma_kernels = + (cols % 32 == 0) && is_supported_by_CC_100_or_newer() && !is_supported_by_CC_120(); if (use_tma_kernels) { fp8::cast_gated_tma(gated_input, grad, output, p, stream); @@ -173,7 +179,7 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise output tensor should be FP8."); } - NVTE_CHECK(is_supported_by_CC_100(), + NVTE_CHECK(is_supported_by_CC_100_or_newer(), "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); mxfp8::quantize_gated(gated_input, grad, output, p, diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index 96a42b494d..a06ed5f046 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -532,7 +532,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { + if (is_supported_by_CC_100_or_newer()) { if (!IS_DBIAS && !IS_DACT) { if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) && is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 1549a292d8..a8926a7408 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -491,6 +491,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } + // Ensure async shared->global copy is done reading shared source before reuse. + ptx::cp_async_bulk_wait_group_read<0>(); + // Ensure all warps reach the reuse boundary before DBIAS scratch writes. 
+ __syncthreads(); + parity ^= 1; if constexpr (IS_DBIAS) { diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 1bdd80a369..52c1f2bb3b 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -281,12 +281,18 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); } -bool is_supported_by_CC_100() { +bool is_supported_by_CC_100_or_newer() { int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); return deviceComputeCapability >= 100; } +bool is_supported_by_CC_120() { + int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); + + return deviceComputeCapability == 120; +} + std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size) { std::vector> ret; diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 12479f2a9c..15e0627591 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1054,7 +1054,9 @@ void create_2D_tensor_map( const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); -bool is_supported_by_CC_100(); +bool is_supported_by_CC_100_or_newer(); + +bool is_supported_by_CC_120(); std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6df7ad35c8..e8f113bff6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -385,7 +385,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { Stats->set_stride({h * s_q, s_q, 1, 1}); @@ -1142,7 +1142,8 @@ void fused_attn_arbitrary_seqlen_fwd( Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + (sm_arch_ != 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 6a7af158e5..2065c5b09a 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -303,8 +303,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { inline void check_grouped_gemm_requirements(const char *api_name) { const int current_device = transformer_engine::cuda::current_device(); - NVTE_CHECK(transformer_engine::cuda::sm_arch(current_device) >= 100, api_name, - " requires Blackwell (SM100) or newer architecture."); + const int sm_arch = transformer_engine::cuda::sm_arch(current_device); + NVTE_CHECK(sm_arch >= 100, api_name, " requires Blackwell (SM100) or newer architecture."); + NVTE_CHECK(sm_arch != 120, 
api_name, + " is currently unsupported on SM120. Grouped cuBLASLt GEMM heuristic selection " + "returns CUBLAS_STATUS_NOT_SUPPORTED on this architecture (even with relaxed hints)."); NVTE_CHECK(transformer_engine::cuda::cublas_version() >= CUBLAS_GROUPED_GEMM_VERSION, api_name, " requires cuBLAS 13.3+, but run-time cuBLAS version is ", transformer_engine::cuda::cublas_version()); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1f1637cecd..f16d9b81cc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -465,11 +465,20 @@ def get_attention_backend( if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for compute capability != sm90") use_flash_attention_3 = False - # FA4 supports SM80, SM90, SM100, SM120 + # FA4 supports SM80, SM90, SM100 if device_compute_capability < (8, 0): if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: logger.debug("Disabling FlashAttention 4 for compute capability < sm80") use_flash_attention_4 = False + # FA4 is temporarily disabled on SM120 due to failures observed with + # SplitKV, Block sparsity / paged KV and likely FAv4/DSL integration issues. + if device_compute_capability == (12, 0): + if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: + logger.warning( + "Disabling FlashAttention 4 on sm120 due to missing bits of support for SplitKV," + " Block sparsity / paged KV and likely FAv4/DSL integration issues." + ) + use_flash_attention_4 = False # On SM90, prefer FA3 over FA4 when FA3 is available. # FA3 is more mature on Hopper; FA4's SM90 backward has limitations # (MLA, non-standard head dims, SplitKV). diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 3ada2459c8..df86fde0ba 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -16,6 +16,7 @@ #include #include "../extensions.h" +#include "../util.h" #include "common.h" #include "common/util/system.h" #include "pybind.h" @@ -67,6 +68,52 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob namespace { +void split_quantize_nvfp4_impl(const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers); + +// Converts the per-group GPU row counts (first_dims, int64 CUDA tensor) +// into a host vector of per-group row counts and returns it. +// The returned vector is used by NVFP4 grouped-quantize to split the input +// tensor into per-group sub-tensors. +// Currently, only used for SM120 NVFP4 grouped-quantize fallback.
+std::vector get_split_sections(std::optional first_dims, size_t num_tensors) { + auto first_dims_tensor = first_dims.value(); + NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, + "Expected first_dims dtype=int64, got scalar_type enum=", + static_cast(first_dims_tensor.scalar_type())); + // D2H copy to CPU + auto first_dims_cpu = first_dims_tensor.contiguous().to(at::kCPU); + NVTE_CHECK(static_cast(first_dims_cpu.numel()) == num_tensors, "Expected ", num_tensors, + " first_dims entries, but got ", first_dims_cpu.numel(), "."); + std::vector split_sections(num_tensors, 0); + const int64_t *first_dims_ptr = first_dims_cpu.data_ptr(); + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(first_dims_ptr[i] >= 0, "first_dims must be non-negative, got ", first_dims_ptr[i], + " at index ", i, "."); + split_sections[i] = static_cast(first_dims_ptr[i]); + } + return split_sections; +} + +// Converts the Python GroupedTensor into a C++ vector of TensorWrappers, +// which are used by NVFP4 grouped-quantize to store the quantized output tensors. +// Currently, only used for SM120 NVFP4 grouped-quantize fallback. +std::vector get_grouped_outputs(const py::object &grouped_output_py, + size_t num_tensors) { + py::list split_outputs = grouped_output_py.attr("split_into_quantized_tensors")(); + NVTE_CHECK(static_cast(py::len(split_outputs)) == num_tensors, "Expected ", num_tensors, + " output tensors, but got ", py::len(split_outputs), "."); + std::vector output_list; + output_list.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list.emplace_back(makeTransformerEngineTensor(split_outputs[i], py::none())); + } + return output_list; +} + // helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy) void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, GroupedTensorWrapper &grouped_output_tensor, @@ -147,10 +194,11 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const using namespace transformer_engine::pytorch::detail; init_extension(); - NVTE_CHECK(tensor.dim() == 2, "Tensor must be 2D"); + auto input_contiguous = tensor.contiguous(); + NVTE_CHECK(input_contiguous.dim() == 2, "Tensor must be 2D"); std::vector logical_shape; - for (const auto &d : tensor.sizes()) { + for (const auto &d : input_contiguous.sizes()) { logical_shape.push_back(d); } const auto logical_first_dim = logical_shape[0]; @@ -162,8 +210,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const // Create input GroupedTensor. auto grouped_input_tensor = GroupedTensorWrapper(num_tensors, logical_shape); - grouped_input_tensor.set_rowwise_data( - tensor.data_ptr(), GetTransformerEngineDType(tensor.scalar_type()), getTensorShape(tensor)); + grouped_input_tensor.set_rowwise_data(input_contiguous.data_ptr(), + GetTransformerEngineDType(input_contiguous.scalar_type()), + getTensorShape(input_contiguous)); // Create output GroupedTensor. 
auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( @@ -196,8 +245,45 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const case GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE: { // NVFP4 grouped quantization NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, - nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream()); + const bool enable_sm120_grouped_nvfp4_fallback = is_sm120_device() && first_dims.has_value(); + // SM120 fallback does not support GEMM-swizzled NVFP4 scale layouts in this path. + if (enable_sm120_grouped_nvfp4_fallback) { + // Use a local quantizer copy so fallback behavior does not mutate shared quantizer state. + NVFP4Quantizer fallback_quantizer = *nvfp4_quantizer_cpp; + fallback_quantizer.optimize_for_gemm = false; + + // As SM120 does not support GEMM-swizzled NVFP4 scale layouts in this path, + // we need to split the input tensor into per-group sub-tensors and quantize them separately. + auto split_sections = get_split_sections(first_dims, num_tensors); + std::vector input_list; + input_list.reserve(num_tensors); + auto *input_dptr = reinterpret_cast(input_contiguous.data_ptr()); + const auto input_dtype = GetTransformerEngineDType(input_contiguous.scalar_type()); + const size_t dim0_stride = logical_first_dim == 0 + ? 0 + : static_cast(input_contiguous.element_size()) * + static_cast(input_contiguous.numel()) / + logical_first_dim; + size_t dim0_offset = 0; + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(dim0_offset + split_sections[i] <= logical_first_dim, + "Split sections exceed input tensor first dimension."); + std::vector split_shape = {split_sections[i], logical_last_dim}; + void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); + input_list.emplace_back( + makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); + dim0_offset += split_sections[i]; + } + // Get the quantized output tensors from the Python GroupedTensor. + auto output_list = get_grouped_outputs(grouped_output_py, num_tensors); + std::vector quantizers(num_tensors, &fallback_quantizer); + auto input_tensor_cpp = makeTransformerEngineTensor(input_contiguous); + split_quantize_nvfp4_impl(input_tensor_cpp, input_list, output_list, split_sections, + quantizers); + } else { + group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, + nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream()); + } break; } case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: { @@ -989,6 +1075,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = split_sections.size(); const auto &quantizer = *quantizers.front(); + const bool sm120_device = is_sm120_device(); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1001,6 +1088,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, bool all_aligned_token_dim = std::all_of(split_sections.begin(), split_sections.end(), [](size_t split_section) { return split_section % 128 == 0; }); + // SM120 fallback: avoid the fully fused grouped row+col RHT kernel path. 
+ all_aligned_token_dim = all_aligned_token_dim && !sm120_device; // in the case when rowwise and colwise cannot be fused, we have to generate the RNG states twice // so that rowwise and colwise will have different random numbers @@ -1019,7 +1108,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, bool with_bulk_generate_rng_states = true; // Stochastic rounding - bool need_stochastic_rounding = quantizer.stochastic_rounding; + bool need_stochastic_rounding = quantizer.stochastic_rounding && !sm120_device; auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, need_separate_rng_states, quant_config_list, quant_config_list_colwise); @@ -1108,6 +1197,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, if (quantizer.columnwise_usage) { std::vector out_transpose_list; std::vector nvte_tensor_out_transpose_list; + std::vector rht_output_t_tensors; + rht_output_t_tensors.reserve(num_tensors); for (size_t i = 0; i < num_tensors; i++) { bool is_empty_split = input_list[i].numel() == 0; auto out_columnwise_data = output_list[i].get_columnwise_data(); @@ -1133,10 +1224,35 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, out_transpose_list.emplace_back(std::move(out_transpose)); nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); } - nvte_group_hadamard_transform_cast_fusion_columnwise( - input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), - rht_matrix_nvte.data(), split_sections.data(), num_tensors, - quant_config_list_colwise_to_use[0], stream); + if (sm120_device) { + // SM120 fallback: avoid grouped columnwise RHT fusion path and run unfused per split. + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) { + continue; + } + const int rows = static_cast(split_sections[i]); + const int cols = static_cast(input_list[i].size(input_list[i].ndim() - 1)); + auto rht_output_t = allocateTorchTensor(cols, rows, input_list[i].dtype()); + rht_output_t_tensors.push_back(rht_output_t); + TensorWrapper rht_output_t_cpp; + rht_output_t_cpp.set_rowwise_data( + rht_output_t.data_ptr(), input_list[i].dtype(), + std::vector{static_cast(cols), static_cast(rows)}); + // SM120 unfused columnwise path (per split): + // 1) Apply RHT on the input and write the result in transposed layout (shape [cols, rows]) into rht_output_t_cpp. + // Columnwise NVFP4 scales are obtained by running rowwise NVFP4 on x_t, so we need the transposed layout here. + nvte_hadamard_transform(input_list[i].data(), rht_output_t_cpp.data(), 0, + quantizer.rht_matrix_random_sign_mask_t, stream); + // 2) NVFP4-quantize the RHT(x_t) output into the columnwise (out_transpose) slot. 
+ nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose_list[i].data(), + quant_config_list_colwise_to_use[i], stream); + } + } else { + nvte_group_hadamard_transform_cast_fusion_columnwise( + input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, + quant_config_list_colwise_to_use[0], stream); + } } } } @@ -1149,6 +1265,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = input_list.size(); const auto &quantizer = *quantizers.front(); + const bool sm120_device = is_sm120_device(); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1171,7 +1288,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // so that we can generate all rng states at once bool with_bulk_generate_rng_states = false; - bool need_stochastic_rounding = quantizer.stochastic_rounding; + bool need_stochastic_rounding = quantizer.stochastic_rounding && !sm120_device; // place holder for colwise rng states, which are not needed in this case std::vector dummy_quant_config_list_colwise; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 82dfe4d222..edfb3841a8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -11,6 +11,7 @@ #include "common/util/system.h" #include "pybind.h" #include "torch/torch.h" +#include "util.h" namespace transformer_engine::pytorch { @@ -1942,7 +1943,13 @@ std::pair NVFP4Quantizer::create_grouped_tenso getTensorShape(*tensor_offsets)); } - out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); + const bool enable_sm120_grouped_nvfp4_fallback = first_dims.has_value() && is_sm120_device(); + // Keep grouped metadata aligned with runtime behavior: + // - default: follow optimize_for_gemm + // - SM120 fallback path: force unswizzled layout + const bool with_gemm_swizzled_scales = + this->optimize_for_gemm && !enable_sm120_grouped_nvfp4_fallback; + out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; @@ -1965,7 +1972,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + kwargs["with_gemm_swizzled_scales"] = with_gemm_swizzled_scales; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { @@ -2241,7 +2248,12 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou quant_config_columnwise.set_noop_tensor(noop_flag->data()); } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); - quant_config.set_stochastic_rounding(this->stochastic_rounding); + + // Disable stochastic-rounding FP4 cast path for SM120, which relies on arch-specific PTX + // instructions. + const bool sm120_device = is_sm120_device(); + const bool use_stochastic_rounding = this->stochastic_rounding && !sm120_device; + quant_config.set_stochastic_rounding(use_stochastic_rounding); // We only need RHT for columnwise usage. 
// flat first dim and last dim for multi dimensional input @@ -2280,11 +2292,11 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 3. Columnwise usage is enabled // 4. Rowwise and columnwise quantization are not fused, // because within a single kernel we can generate two different random numbers for rowwise and columnwise - const bool need_separate_columnwise_rng = this->stochastic_rounding && this->with_rht && + const bool need_separate_columnwise_rng = use_stochastic_rounding && this->with_rht && this->columnwise_usage && (!eligible_for_rht_cast_fusion); - if (this->stochastic_rounding) { + if (use_stochastic_rounding) { const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened auto gen = at::get_generator_or_default( std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 132db4075f..5eb51721df 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -13,6 +13,7 @@ #include #include +#include "common/util/cuda_runtime.h" #include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -65,6 +66,11 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW */ at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise); +/*! \brief Check whether the current CUDA device is SM120. */ +inline bool is_sm120_device() { + return transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) == 120; +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index ac56d334bc..712b4eac62 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -1047,7 +1047,9 @@ def split_into_quantized_tensors( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=quantizer.dtype, quantizer=quantizer, - with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + # Use the actual grouped-output layout. This can differ from the requested + # quantizer flag if the backend produces a different layout (e.g. sm120) + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, ) result.append(tensor) @@ -1182,7 +1184,9 @@ def split_into_quantized_tensors( amax_columnwise=amax_columnwise, fp4_dtype=quantizer.dtype, quantizer=quantizer, - with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + # Use the actual grouped-output layout. This can differ from the requested + # quantizer flag if the backend produces a different layout (e.g. sm120) + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, ) result.append(tensor)