From d2fd00287761dc5b34130ffffe664f4069c00223 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Tue, 20 Jan 2026 09:14:08 -0800
Subject: [PATCH 01/22] Changed VERSION to 2.12.0

Signed-off-by: Przemek Tredak
---
 build_tools/VERSION.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt
index d5e1cb291..d8b698973 100644
--- a/build_tools/VERSION.txt
+++ b/build_tools/VERSION.txt
@@ -1 +1 @@
-2.12.0.dev0
+2.12.0

From 6add8c95fe0f16d63389d35a2972682b26d3c7a9 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Tue, 20 Jan 2026 13:59:29 -0800
Subject: [PATCH 02/22] [Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell (#2584)

* update FE to 1.17
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism flag
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to qa/
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move bias/dbias/versioning/dropout logic to C API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* Update qa/L0_pytorch_unittest/test.sh
make .xml file specific to deterministic tests in qa/
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to Jax extension
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to Jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* Update tests/jax/test_fused_attn.py
fix typo
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/common/fused_attn/fused_attn.cpp
fix indentation
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fix the AI fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fix Jax extension call
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes based on comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix selection logic and fwd arg
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fix version check in Jax test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pytorch CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fix non-/determinism logic and CI
Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix formatting Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/fused_attn/fused_attn.cpp fix and/or logic Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to 9.18.1 for requirement Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * reduce Jax CI tests for determinism Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- qa/L0_jax_unittest/test.sh | 1 + qa/L0_pytorch_unittest/test.sh | 1 + tests/jax/test_fused_attn.py | 212 +++++++++++++++++- tests/pytorch/attention/test_attention.py | 41 +++- .../common/fused_attn/fused_attn.cpp | 41 ++-- .../include/transformer_engine/fused_attn.h | 3 +- .../jax/cpp_extensions/attention.py | 21 +- transformer_engine/jax/csrc/extensions.h | 2 +- .../jax/csrc/extensions/attention.cpp | 8 +- .../attention/dot_product_attention/utils.py | 5 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cpp | 4 +- 13 files changed, 299 insertions(+), 44 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4..b372d3987 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2 diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index ee9ce130a..3453e35d2 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" mkdir -p "$XML_LOG_DIR" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eed2836..a13dfada7 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e python3 -m pytest --tb=auto 
--junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index a0aee5043..f9946e1f7 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Tests for fused attention""" +import os from enum import Enum, auto from dataclasses import dataclass, field from functools import partial @@ -49,6 +50,9 @@ from distributed_test_base import assert_equal_collectives from utils import assert_allclose, print_debug_tensor_stats +# Get determinism +_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + @pytest.fixture(autouse=True, scope="module") def init(): @@ -413,15 +417,25 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support - if ( - get_device_compute_capability(0) >= 100 - and self.dropout_prob == 0.1 - and self.attn_bias_type is not AttnBiasType.NO_BIAS - ): - pytest.skip( - "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" - ) + + if get_device_compute_capability(0) >= 100 and self.is_training: + if FusedAttnHelper.is_non_deterministic_allowed() and ( + (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) + or get_cudnn_version() < 90700 + ): + pytest.skip( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with" + " dropout" + ) + if not FusedAttnHelper.is_non_deterministic_allowed() and ( + self.dropout_prob != 0.0 + or self.attn_bias_type != AttnBiasType.NO_BIAS + or get_cudnn_version() < 91801 + ): + pytest.skip( + "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or" + " dropout" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): @@ -1269,6 +1283,7 @@ def check_dqkv(primitive, reference, pad, idx): pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"), ], ) +@pytest.mark.skipif(_deterministic, reason="Test non-determinism only") 
class TestFusedAttn: """ Fused attention tester @@ -1392,3 +1407,182 @@ def test_backward( seq_desc_format, ) runner.test_backward() + + +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), + pytest.param( + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT" + ), + ], +) +@pytest.mark.parametrize( + "softmax_type", + [ + pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), + ], +) +@pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", + [ + # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate + pytest.param( + 2, + 1024, + 2048, + 12, + 6, + 128, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE", + ), + pytest.param( + 2, + 1024, + 2048, + 12, + 6, + 128, + 64, + jnp.bfloat16, + QKVLayout.THD_THD_THD, + id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", + ), + ], +) +@pytest.mark.parametrize( + "dropout_prob", + [ + pytest.param(0.0, id="DROP_0.0"), + ], +) +@pytest.mark.parametrize( + "swa", + [ + pytest.param(False, id="NO_SWA"), + ], +) +@pytest.mark.parametrize( + "seq_desc_format", + [ + pytest.param(SeqDescFormat.Seqlens, id="Seqlens"), + ], +) +@pytest.mark.skipif(not _deterministic, reason="Test determinism only") +class TestFusedAttnWithDeterminism: + """ + Fused attention tester with determinism + """ + + @staticmethod + @pytest.mark.parametrize( + "is_training", + [ + pytest.param(True, id="TRAINING"), + ], + ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + def _test_forward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ): + """ + Test forward with parameterized configs + This test is not intended to run automatically during CI as it is time-consuming + It is kept for development and debugging + """ + TestFusedAttn._test_forward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ) + + @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + def test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ): + """ + Test backward with parameterized configs + """ + TestFusedAttn.test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index eb7905bcd..9111d3511 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -72,6 
+72,14 @@ f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" ) + +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + + # Reset RNG seed and states seed = 1234 reset_rng_states() @@ -160,6 +168,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -170,6 +179,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -886,11 +896,14 @@ def _run_dot_product_attention( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create seqlens @@ -1292,6 +1305,7 @@ def test_transformer_layer( qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -1305,6 +1319,7 @@ def test_transformer_layer( else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -1432,10 +1447,13 @@ def _run_transformer_layer( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create input tensor @@ -1629,6 +1647,7 @@ def test_dpa_fp8_extra_state(model, dtype): qkv_dtype=torch.float8_e4m3fn, qkv_layout="sb3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported and not flash_attn_supported: @@ -1819,6 +1838,7 @@ def test_mha_fp8_vs_f16( fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported_fp8 < 1: @@ -1830,6 +1850,7 @@ def test_mha_fp8_vs_f16( qkv_dtype=dtype, qkv_layout=qkv_format.replace("hd", "h3d"), is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported_f16, _ = available_backends if not fused_attn_supported_f16: @@ -1838,6 +1859,7 @@ def test_mha_fp8_vs_f16( if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" 
_attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -1847,6 +1869,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_fp8: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -1856,6 +1879,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_f16: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( @@ -2068,6 +2092,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported < 1: @@ -2078,6 +2103,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -2088,6 +2114,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2097,6 +2124,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if unfused_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2105,6 +2133,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2113,6 +2142,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") @@ -2367,13 +2397,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model): qkv_dtype=torch.float8_e4m3fn, qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", is_training=is_training, + deterministic=_deterministic, ) 
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not (fused_attn_backends and unfused_attn_supported): pytest.skip("Not enough backends to run this test with.") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") - unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16( + dtype, config, "UnfusedDotProductAttention" + ) atol = 5e-1 rtol = 5e-1 @@ -2406,10 +2439,13 @@ def _run_custom_mha_fp8(dtype, config, backend): reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = 0.0001 * torch.randint( @@ -2460,10 +2496,13 @@ def _run_ref_mha_f16(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = torch.load("qkv.pt").to(device="cuda") diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index fde0d3892..415bfae06 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -206,7 +206,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -440,7 +440,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.13.1+: vanilla, off-by-one, learnable (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && + // determinism on Blackwell + // pre-9.18.1: fwd: deterministic; bwd: non-deterministic + // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic + (sm_arch_ < 100 || + (sm_arch_ >= 100 && (!is_training || + (is_training && !deterministic && + (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || + (is_training && deterministic && cudnn_runtime_version >= 91801 && + dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -553,7 +562,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, 
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph); + cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -595,7 +604,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -669,7 +679,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, + deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -855,7 +866,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -897,7 +908,8 @@ void nvte_fused_attn_fwd_kvpacked( input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
" + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -982,10 +994,10 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, - softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1166,7 +1178,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1189,7 +1201,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -1262,7 +1275,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph); + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index dd70ccf8d..0fabb81ae 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_right Sliding window size (the right half). * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] deterministic Whether determinism is required or not. 
*/ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph); + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); /*! \brief Compute dot product attention with packed QKV input. * diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 0cdfcebf3..ee10115aa 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -144,6 +144,7 @@ def get_fused_attn_backend(self): self.head_dim_v, self.window_size[0], self.window_size[1], + not self.is_non_deterministic_allowed(), ) @staticmethod @@ -3563,13 +3564,21 @@ def fused_attn_bwd( softmax_offset, (None, HEAD_AXES, None, None) ) - # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on - # sm100+ compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities): - assert not ( - attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 - ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" + if any(x >= 100 for x in compute_capabilities) and is_training: + assert ( + FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 7, 0) + and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0) + ) or ( + not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 18, 1) + and attn_bias_type == AttnBiasType.NO_BIAS + and dropout_probability == 0.0 + ), ( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout," + " and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout" + ) fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index a83a1e0a8..5f9339263 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -113,7 +113,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 540aeb8b2..4fe8e728a 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t 
window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool deterministic) { auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); return backend; } @@ -266,7 +266,7 @@ static void FusedAttnForwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -522,7 +522,7 @@ static void FusedAttnBackwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bf19388d7..cb74a15e7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -994,6 +994,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt window_size[1], return_max_logit, cuda_graph, + deterministic, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") @@ -1064,10 +1065,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None - if is_training and device_compute_capability >= (10, 0): - logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") - use_fused_attention = False - fused_attention_backend = None # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9dc0d1f37..591c89f83 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -81,7 +81,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph); + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, 
float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index b455e0375..be645d91b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, deterministic); return fused_attention_backend; } From cfabd833d84805585025d0b6f6b680caebbc0c75 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:05:27 +0100 Subject: [PATCH 03/22] [Common] Tuned NVFP4 cast kernel (#2412) * Implemented persistent nvfp4 kernel Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix FP4 guard in ptx Signed-off-by: Oleg Goncharov * Fix Signed-off-by: Oleg Goncharov * Fix in ptx. reduxf32 guard Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Oleg Goncharov * Fixes per PR review Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes per PR review. Added parameter to turn off the persistency Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Modified reference CPU implementation in C++ unit tests to match GPU (numerical truncation). 
Tightened the numerical tolerance Signed-off-by: Oleg Goncharov * Disabled persistency by default, as non-persistent kernel is more performant when inputs are large Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use the tuned kernel also for the rowwise only quantization Signed-off-by: Oleg Goncharov * Fixed typo Signed-off-by: Oleg Goncharov * Addressed comments from the PR review Signed-off-by: Oleg Goncharov * Resolved conflicts Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Macros renaming Signed-off-by: Oleg Goncharov --------- Signed-off-by: Oleg Goncharov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 130 +-- .../common/cast/core/common.cuh | 6 + .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 7 + .../quantize_transpose_nvfp4_tuned_1D.cuh | 789 ++++++++++++++++++ transformer_engine/common/util/ptx.cuh | 306 +++++++ 5 files changed, 1184 insertions(+), 54 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 1904d03df..c4df8759f 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -54,12 +54,16 @@ std::vector create_transpose(const InputType* const input, const size } // Compute the global encode scale factor for a given global amax -float compute_global_encode_scaling_factor_FP4(const float global_amax) { +float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) { constexpr float fp8_max = 448.0f; // 448.0f; constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; - // If scale is infinity, return max value of float32 - global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If scale is infinity, return the max normalized value + const float max_norm_clamp = use_fast_math + ? Numeric_Traits::maxNorm + : Numeric_Traits::maxNorm; + + global_encode_scale = fminf(global_encode_scale, max_norm_clamp); // If global amax is 0 or infinity, return 1 if (global_amax == 0.0f || global_encode_scale == 0.0f) { return 1.0f; @@ -76,10 +80,11 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t rows, const size_t cols, const size_t scales_stride, - const float global_amax) { + const float global_amax, + const bool use_fast_math) { // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_X = 16; const size_t blocks_X = divide_round_up(cols, block_size_X); @@ -114,14 +119,20 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float S_dec_b = block_amax / 6.0f; // Scale & Store per-block decoding scaling factor - const float S_dec_b_fp8 = S_dec_b * S_enc; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 
0.f : S_enc / S_dec_b_fp32; const size_t scale_idx = i * scales_stride + block_X; - scales[scale_idx] = static_cast(S_dec_b_fp8); - const float scale_reciprocal = S_enc_b_fp8; + scales[scale_idx] = S_dec_b_fp8; + + float scale_reciprocal = S_enc_b_fp8; + if (use_fast_math) { + // Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } for (size_t j = j_min; j < j_max; j += 2) { const int idx_pair = (i * cols + j) / 2; @@ -136,7 +147,7 @@ void quantize_nvfp4_1d(float (*OP)(const float), fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); output[idx_pair] = casted_to_e2m1_pair; - // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); } } } @@ -149,9 +160,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float), const size_t rows, const size_t cols, const float global_amax, - std::vector>& math_scales) { + std::vector>& math_scales, + const bool use_fast_math) { - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -195,13 +207,14 @@ void quantize_nvfp4_2d(float (*OP)(const float), const size_t rows, const size_t cols, const size_t scales_stride, - const float global_amax) { + const float global_amax, + const bool use_fast_math) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -282,11 +295,12 @@ void quantize_nvfp4(float (*OP)(const float), const size_t cols, const size_t scales_stride, const float global_amax, + const bool use_fast_math, const bool use_2d_quantization = false) { if (use_2d_quantization) { - quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); } else { - quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); } } @@ -302,6 +316,7 @@ void compute_ref(float (*OP)(const float), const size_t cols, const size_t scales_stride, const size_t scales_stride_t, + const bool use_fast_math, const bool use_2d_quantization = false) { std::vector input_t = create_transpose(input, rows, cols); @@ -309,7 +324,7 @@ void compute_ref(float (*OP)(const float), if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -336,12 +351,16 @@ void compute_ref(float (*OP)(const float), 
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) - quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled - quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax, + use_fast_math); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax, + use_fast_math); // scales_t already filled } else { - quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); - quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, + use_fast_math, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, + use_fast_math, use_2d_quantization); } } @@ -349,6 +368,8 @@ void compare_nvfp4_tensors(const std::string& name, const fp4e2m1 *test_data, const fp4e2m1 *ref_data, const int rows, const int cols, double atol = 1e-5, double rtol = 1e-8) { + constexpr int max_mismatches_to_print = 3; + std::vector mismatch_messages; size_t total_mismatches = 0; @@ -362,29 +383,16 @@ void compare_nvfp4_tensors(const std::string& name, const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = false; - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); - } - if (assertion) { + const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); + if (mismatch) { total_mismatches++; - std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + - std::to_string(t) + " vs " + std::to_string(r) + - " (abs_diff: " + std::to_string(fabs(t - r)) + - ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; - mismatch_messages.push_back(msg); - // Optional: limit number of detailed messages to avoid overwhelming output - if (mismatch_messages.size() <= 100) { + if (total_mismatches <= max_mismatches_to_print) { + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 
0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); std::cout << "Error in tensor " << name << ": " << msg << std::endl; } } @@ -400,8 +408,9 @@ void compare_nvfp4_tensors(const std::string& name, std::cout << "STATUS: FAILED for output" << std::endl; std::cout << "Total mismatches found: " << total_mismatches << std::endl; std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; - if (mismatch_messages.size() > 100) { - std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + if (mismatch_messages.size() > max_mismatches_to_print) { + std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print) + << " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl; } std::cout << "============================" << std::endl; @@ -519,7 +528,8 @@ void compareResults_nvfp4(const Tensor &test, template void performTest(float (*OP)(const float), - const std::vector& shape) { + const std::vector& shape, + const bool use_fast_math) { using namespace test; DType itype = TypeInfo::dtype; @@ -580,15 +590,16 @@ void performTest(float (*OP)(const float), cols, scales_stride, scales_stride_t, + use_fast_math, use_2d_quantization); - - QuantizationConfigWrapper quant_config; - // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence rng_state.from_cpu(); + + QuantizationConfigWrapper quant_config; + quant_config.set_use_fast_math(use_fast_math); quant_config.set_stochastic_rounding(false); quant_config.set_rng_state(rng_state.data()); @@ -619,8 +630,8 @@ void performTest(float (*OP)(const float), } ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - const double atol = 0.05; - const double rtol = 0.1; + const double atol = 1.0E-6; + const double rtol = 1.0E-6; // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); @@ -666,12 +677,18 @@ std::vector Activation_types = { ActivationType::Identity }; +std::vector use_fast_nvfp4_scaling_vec = { + false, + true +}; + } // namespace class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { // Skip tests for pre-Blackwell architectures @@ -685,6 +702,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const ActivationType Act_type = std::get<0>(GetParam()); const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); + const bool use_fast_math = std::get<3>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -702,7 +720,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims); + performTest(OP, tensor_dims, use_fast_math); ); } @@ -724,7 +742,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), - ::testing::Values(DType::kBFloat16)), + ::testing::Values(DType::kBFloat16), + ::testing::ValuesIn(use_fast_nvfp4_scaling_vec)), [](const testing::TestParamInfo& info) { std::string name = 
to_string(std::get<0>(info.param)); const auto& shape = std::get<1>(info.param); @@ -732,5 +751,8 @@ INSTANTIATE_TEST_SUITE_P( name += "X" + std::to_string(s); } name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<3>(info.param)) { + name += "X_FAST_SCALING"; + } return name; }); diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index a5c8327cd..0997b01f7 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -35,6 +35,12 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) { return cols % alignment_requirement == 0; } +__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) { + size_t addr = reinterpret_cast(p); + addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1); + return reinterpret_cast(addr); +} + namespace kernel { constexpr size_t THREADS_PER_BLOCK = 256; diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 5da9cc5a5..99776db28 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -21,6 +21,7 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" +#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh" namespace transformer_engine { namespace dispatch { @@ -1159,6 +1160,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, #if FP4_TYPE_SUPPORTED using namespace quantize_transpose_kernel; using namespace ptx; + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to @@ -1166,6 +1168,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); + if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { + quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); + return; + } + constexpr bool COMPUTE_ACTIVATIONS = false; using ParamOP = Empty; constexpr float (*OP)(float, const ParamOP &) = nullptr; diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh new file mode 100644 index 000000000..af1b01d6b --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -0,0 +1,789 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_transpose_nvfp4_tuned_1D.cuh + * \brief Tuned kernel to cast to NVFP4 and transpose. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ + +#include +#include +#include +#include + +#include "../../../common.h" +#include "../../../util/math.h" +#include "../../../util/ptx.cuh" +#include "../../../utils.cuh" +#include "../core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace quantize_transpose_tuned_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +struct TunableConfig { + static constexpr int CHUNK_DIM_Y = 128; + static constexpr int CHUNK_DIM_X = 128; + static constexpr int PREFETCH_STAGES = 1; + static constexpr bool PERSISTENT = false; +}; + +constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) +constexpr int THREADS_NUM = 128; +constexpr int ELTS_PER_THREAD = 16; +constexpr int TILE_DIM_Y = 64; +constexpr int TILE_DIM_X = 64; + +static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0"); + +static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) && + "Unbalanced threads workload\0"); + +static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) && + "Chunk size Y must be evenly divisible by the tile size Y\0"); +static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) && + "Chunk size X must be evenly divisible by the tile size X\0"); + +static_assert((TILE_DIM_Y % SCALE_DIM == 0) && + "Tile size Y must be evenly divisible by the scale dim\0"); +static_assert((TILE_DIM_X % SCALE_DIM == 0) && + "Tile size X must be evenly divisible by the scale dim\0"); + +constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y; +constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X; + +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; + +constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM; + +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; + +constexpr int STAGES_Y = TILES_Y; +constexpr int STAGES_X = TILES_X; +constexpr int STAGES = STAGES_Y * STAGES_X; + +constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1; +constexpr int BUFFS_NUM_IN = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT_TR = 2; +constexpr int BUFF_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_DIM_X = TILE_DIM_X; +constexpr int BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr int BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr int BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; +constexpr int BUFF_IN_ELTS_NUM = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr int BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr int BUFF_OUT_TR_DIM_Y = BUFF_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; + +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / 
THREADS_X_ROWWISE; + +constexpr int THREADS_X_TR = TILE_DIM_X / 2; +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; + +constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; +constexpr int ITERATIONS_TR = SCALES_PER_TILE_Y / THREADS_Y_TR; +static_assert(ITERATIONS_TR >= 1 && "Number of transpose iterations should be >=1\0"); +static_assert((SCALES_PER_TILE_Y % THREADS_Y_TR == 0) && + "Partial transpose iterations are not supported\0"); + +constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_TR_DIM_X / ITERATIONS_TR / STAGES; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; + +using IType = bf16; +using IType2 = typename ptx::FPx2; +using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; +using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>; + +template +struct SCALING_COEFFICIENT_TYPE {}; +template <> +struct SCALING_COEFFICIENT_TYPE { + using type = float; +}; +template <> +struct SCALING_COEFFICIENT_TYPE { + using type = bf16; +}; + +__device__ __forceinline__ float get_amax_of_pair(const IType2 pair) { + return static_cast(__hmax(__habs(pair.x), __habs(pair.y))); +} + +// Compute "correct" per-block encoding scaling factor +template +__device__ __forceinline__ SF_TYPE +compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { + constexpr float float_max = detail::TypeExtrema::max; + const float scale_rcp = fminf(S_enc / static_cast(S_dec_block), float_max); + return static_cast(scale_rcp); +} + +template +__device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, + const float S_enc_colwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, + uint4 &random_uint4, int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; + + const auto &sIn2x = *reinterpret_cast(sIn_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; + const int tid_X_colwise = thread_lane; + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; + const int thread_offset_X_colwise = tid_X_colwise * 2; + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = 
thread_offset_X_colwise / 2; + + const int out_tr_thread_offset_Y = thread_offset_X_colwise; + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; + + __align__(8) IType rIn[2][SCALE_DIM]; + // Read (cache) a pair of input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + const float block_amax[2] = {static_cast(__habs(thread_amax_2x.x)), + static_cast(__habs(thread_amax_2x.y))}; +#pragma unroll + for (int w = 0; w < 2; ++w) { + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_colwise); + + // Scale elements + __align__(8) uint32_t rOut[SCALE_DIM / 8]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][8 * e]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][8 * e + 4]); + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + } + uint64_t &out_pack_16x = *reinterpret_cast(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } +} + +template +__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, + const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out, RNG_t &rng, uint4 &random_uint4, + int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; + + const auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; + + const int SF_thread_offset_rowwise_Y = tid_Y_rowwise; + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; + + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y; + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise 
= thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + // Read (cache) input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + // Load elements + __uint128_t &elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = get_amax_of_pair(thread_amax_2x); + + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } + +// Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const uint64_t elts03 = *reinterpret_cast(&rIn[w][0]); + const uint64_t elts47 = *reinterpret_cast(&rIn[w][2]); + + uint32_t out_x8; + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +template +__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG_t rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. 
It increments each time when used and resets to 0 if reaches 4x + int rnd_idx = 0; + + const bool leading_thread = (threadIdx.x == 0); + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int in_mem = buff_size_aligned_in; + + constexpr int out_mem_rowwise_data = buff_size_aligned_out; + constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType *sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2 *sOut_ptr = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *sOut_tr_ptr = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + + nvfp4_scale_t *sSFrowwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *sSFcolwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = + (amax_rowwise_ptr == nullptr) + ? 1.0f + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + + const float S_enc_colwise = + (amax_colwise_ptr == nullptr) + ? 
S_enc_rowwise + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + + __shared__ uint64_t workID_mbar; + __shared__ __uint128_t workID_response; + constexpr uint32_t workID_response_size = sizeof(workID_response); + static_assert(workID_response_size == 16); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + // Coordinates of the first chunk (CTA) to process + int32_t ctaid_X = blockIdx.x; + int32_t ctaid_Y = blockIdx.y; + + // Initialize shared memory barriers with the number of threads participating in them + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::mbarrier_init(&workID_mbar, 1); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + bool job_finished = false; + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + int ctaid_parity = 0; + +// Prefetch input data only when processing the first chunk, +// which enables the one-iteration overlap throughout the entire kernel life +#pragma unroll + for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + + uint64_t *barrier = &IN_buff_readable_mbar[buff_in]; + if (leading_thread) { + uint64_t *dst = reinterpret_cast(&sIn[buff_in]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + } + + while (!job_finished) { + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; + const int block_offset_X_tr = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + + const int chunk_rows = rows - block_offset_Y; + const int chunk_cols = cols - block_offset_X; + + const int scales_block_offset_Y_rowwise = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X; + const int scales_block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; + const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y; + + if constexpr (TunableConfig::PERSISTENT) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + if (stage == STAGES - TunableConfig::PREFETCH_STAGES) { + if constexpr (TunableConfig::PERSISTENT) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + 
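// Decode the try_cancel response below: if a pending, not-yet-launched CTA was cancelled, its (X, Y) coordinates become this CTA's next chunk to process; (-1, -1) means no work is left to steal. +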
ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + ctaid_parity ^= 1; + } else { + ctaid_X = -1; + ctaid_Y = -1; + } + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } + } + + // Prefetch next stage Input data + if (!job_finished || (stage < STAGES - TunableConfig::PREFETCH_STAGES)) { + const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES; + const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; + const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; + + const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y; + const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X; + + // Offsets change, because the coordinates of the next "to-be-prefetched" CTA also change + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const int global_offset_X = block_offset_X + next_prefetch_stage_offset_X; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + if (leading_thread) { + uint64_t *dst = reinterpret_cast(&sIn[next_prefetch_buff]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // Wait for TMA transfer to have finished reading shared memory + // I.e.
the OUT buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read(); + + // NVFP4 Quantization + rowwise_scaling( + sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, + rng, random_uint4, rnd_idx); + + if constexpr (RETURN_TRANSPOSE) { + colwise_scaling( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, stage_Y, stage_X, buff_in, + buff_out_tr, rng, random_uint4, rnd_idx); + } + + // Wait for shared memory writes to be visible to TMA engine + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine + + // Initiate TMA transfer to copy shared memory to global memory + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X; + const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, + global_offset_Y, reinterpret_cast(&sOut[buff_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, + global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation + ptx::cp_async_bulk_commit_group(); + } + + buff_in = (buff_in + 1) % BUFFS_NUM_IN; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } // end of stages + + // Vectorized store of scaling factors (S2G) + { + // Rowwise + { + using ScalesVec = Vec; + // number of scales in X dimension of this chunk + const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); + + for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) { + const size_t row_global = scales_block_offset_Y_rowwise + row; + if (row_global < rows) { + ScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t scale_idx_global = + row_global * scale_stride + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); + } + } + } + + // Colwise + if constexpr (RETURN_TRANSPOSE) { + using ScalesVec = Vec; + // number of scales in Y dimension of this chunk + const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM); + + for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; + row_tr += THREADS_NUM) { + const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; + if (row_tr_global < cols) { + ScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t scale_idx_global = + row_tr_global * scale_stride_t + scales_block_offset_X_tr; + scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); + } + } + } + + if (!job_finished) { + // Ensures all reads from SFs buffer have completed and it's ready to be reused + __syncthreads(); + } + } + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + ptx::mbarrier_invalid(&workID_mbar); + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +#endif // FP4_TYPE_SUPPORTED +} // namespace quantize_transpose_tuned_kernel + +inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output, 
+ const QuantizationConfig *quant_config, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_transpose_tuned_kernel; + using namespace ptx; + + const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + + // If transposed output is allocated, return the transposed data + // Otherwise, it's not necesary to return the transposed data. + const bool return_transpose = output->has_columnwise_data(); + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if (return_transpose) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const int blocks_Y = DIVUP(rows, static_cast(TunableConfig::CHUNK_DIM_Y)); + const int blocks_X = DIVUP(cols, static_cast(TunableConfig::CHUNK_DIM_X)); + const dim3 grid(blocks_X, blocks_Y); + const int block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? 
quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + const int in_mem = buff_size_aligned_in; + + const int out_data_mem = buff_size_aligned_out; + const int out_data_transpose_mem = return_transpose ? buff_size_aligned_out_t : 0; + const int out_scales_mem = buff_size_scales; + const int out_scales_transpose_mem = return_transpose ? 
buff_size_scales_transpose : 0; + + const int out_mem = out_data_mem + out_data_transpose_mem; + + const int dshmem_size = + in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }););); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 4cdd8297a..9bcf6e228 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -164,6 +164,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta( + uint64_t *mbar, const uint32_t tx_count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr), + "r"(tx_count)); +#else + NVTE_DEVICE_ERROR( + "mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + __device__ __forceinline__ void fence_mbarrier_init_release_cluster() { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile("fence.mbarrier_init.release.cluster;"); @@ -243,6 +255,75 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar, + uint32_t phase_parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "{\n\t" + ".reg .b64 r1; \n\t" + ".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met + "WAIT: \n\t" // loop around barrier wait + "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t" + "@waitComplete bra DONE; \n\t" // mbarrier conditions are met + "bra WAIT; \n\t" // just a time-out, try again + "DONE: \n\t" + "}\n\t" + : + : "r"(mbar_ptr), "r"(phase_parity) + : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + uint32_t workID_response = 
__cvta_generic_to_shared(response_data_ptr); + asm volatile( + "clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::" + "all.b128 " + "[%0], [%1];" ::"r"(workID_response), + "r"(mbar_ptr)); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr, + int32_t &ctaid_X, int32_t &ctaid_Y) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr); + asm volatile( + "{\n\t" + ".reg .s32 x_ctaid; \n\t" + ".reg .s32 y_ctaid; \n\t" + "mov .s32 x_ctaid, -1; \n\t" + "mov .s32 y_ctaid, -1; \n\t" + ".reg.b128 try_cancel_response; \n\t" + "ld.shared.b128 try_cancel_response, [%2]; \n\t" + ".reg .pred P1; \n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t" + "@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, " + "_}, try_cancel_response; \n\t" + "mov .s32 %0, x_ctaid; \n\t" + "mov .s32 %1, y_ctaid; \n\t" + "}\n\t" + : "=r"(ctaid_X), "=r"(ctaid_Y) + : "r"(workID_response) + : "memory"); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_EXPONENT_BIAS = 127; @@ -657,6 +738,179 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); } } + +template +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest( + const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) { + uint32_t out_8x = 0; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + // Elements reordered to match e2m1x4 packing order (v1,v0) + "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient))); + } else if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.b64 scaling_coeff_2x; \n\t" + "mov.b64 scaling_coeff_2x, {%3, %3}; \n\t" + ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t" + "mov.b64 {v0_bf16, 
v1_bf16, v2_bf16, v3_bf16}, %1; \n\t" + "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t" + + ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "cvt.f32.bf16 v4, v4_bf16; \n\t" + "cvt.f32.bf16 v5, v5_bf16; \n\t" + "cvt.f32.bf16 v6, v6_bf16; \n\t" + "cvt.f32.bf16 v7, v7_bf16; \n\t" + + ".reg.b64 v01, v23, v45, v67; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mov.b64 v45, {v4, v5}; \n\t" + "mov.b64 v67, {v6, v7}; \n\t" + "mul.f32x2 v01, v01, scaling_coeff_2x; \n\t" + "mul.f32x2 v23, v23, scaling_coeff_2x; \n\t" + "mul.f32x2 v45, v45, scaling_coeff_2x; \n\t" + "mul.f32x2 v67, v67, scaling_coeff_2x; \n\t" + // Elements reordered to match the packing order (v1,v0) + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "mov.b64 {v5, v4}, v45; \n\t" + "mov.b64 {v7, v6}, v67; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n\t" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "f"(scaling_coefficient)); + } else { + NVTE_DEVICE_ERROR("Not supported scaling coefficient type."); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} + +template +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient, + const uint32_t rbits03, const uint32_t rbits47) { + uint32_t out_8x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient)), + "r"(rbits03), "r"(rbits47)); + } else if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t" + "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t" + + ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, 
v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "cvt.f32.bf16 v4, v4_bf16; \n\t" + "cvt.f32.bf16 v5, v5_bf16; \n\t" + "cvt.f32.bf16 v6, v6_bf16; \n\t" + "cvt.f32.bf16 v7, v7_bf16; \n\t" + + "mul.f32 v0, v0, %3; \n\t" + "mul.f32 v1, v1, %3; \n\t" + "mul.f32 v2, v2, %3; \n\t" + "mul.f32 v3, v3, %3; \n\t" + "mul.f32 v4, v4, %3; \n\t" + "mul.f32 v5, v5, %3; \n\t" + "mul.f32 v6, v6, %3; \n\t" + "mul.f32 v7, v7, %3; \n\t" + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47)); + } else { + NVTE_DEVICE_ERROR("Not supported scaling coefficient type."); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} + #endif // FP4_TYPE_SUPPORTED // SIMD like "Fused" cast + multiplication (x2) @@ -1508,6 +1762,58 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) { return out; } +// Loads single BF16/FP16 element from shared memory state space +__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16 dst; + asm volatile("ld.shared.b16 %0, [%1];" + : "=h"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); + return dst; +} + +// Loads pair of BF16/FP16 values from shared memory state space +__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16x2 dst; + asm volatile("ld.shared.b32 %0, [%1];" + : "=r"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); + return dst; +} + +// Loads 8x BF16 values from shared memory state space +__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) { + uint64_t elts03, elts47; + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + asm volatile( + "{\n\t" + ".reg.b128 xy; \n\t" + "ld.shared.b128 xy, [%2]; \n\t" + "mov.b128 {%0, %1}, xy; \n" + "}\n" + : "=l"(elts03), "=l"(elts47) + : "r"(src_smem_ptr)); + return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03); +} + +#if FP4_TYPE_SUPPORTED +// Vectorized store of x8 FP4 elements into shared memory state space +__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem, + uint32_t fp4_pack_x8) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8)); +} +#endif + +// Vectorized store of x16 FP4 elements into shared memory state space +#if FP4_TYPE_SUPPORTED +__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem, + uint64_t fp4_pack_x16) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16)); +} +#endif } // namespace ptx namespace { From 42e803d4b126d44dbc824fbd144ea76a1a189dc9 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Wed, 21 Jan 2026 19:10:09 +0100 Subject: [PATCH 04/22] Fixed the year to 2026 (#2611) Signed-off-by: Oleg Goncharov --- .../nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index af1b01d6b..411900168 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ From d759aa6412f4241e081e46fb7b5b4ec3d9ec54ee Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 21 Jan 2026 14:25:09 -0800 Subject: [PATCH 05/22] [pyTorch] CPU performance optimizations (#2439) * PoC of the changes Signed-off-by: Przemek Tredak * Early exit from the Free function for the empty tensor Signed-off-by: Przemek Tredak * Use the proper function for nvtx range Signed-off-by: Przemek Tredak * Only do mark_not_offload when the cpu_offloading is enabled Signed-off-by: Przemek Tredak * First pass on making the setattr issue not come back Signed-off-by: Przemek Tredak * Actually add pytest.ini Signed-off-by: Przemek Tredak * Changes to __init__ Signed-off-by: Przemek Tredak * A different way Signed-off-by: Przemek Tredak * WAR the fact that it is not possible to set __setattr__ dynamically Signed-off-by: Przemek Tredak * Simpler solution and fixes Signed-off-by: Przemek Tredak * Fix for the inference mode DPA Signed-off-by: Przemek Tredak * Start of debugging debug tools Signed-off-by: Przemek Tredak * More fixes in debug Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Speculative moving the validate_name to the constructor Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Przemek Tredak * Making the debug tools names saner Signed-off-by: Przemek Tredak * Change the setattr usage in the tensor parallel group setting Signed-off-by: Przemek Tredak * Adding try/finally - it does not seem to impact the time in observable way Signed-off-by: Przemek Tredak * Fixing lint issues and the thunder test Signed-off-by: Przemek Tredak * Fix 1 of the debug tests Signed-off-by: Przemek Tredak * Removed the warning and enforcement in the CI Signed-off-by: Przemek Tredak * try-finally in the context manager Signed-off-by: Przemek Tredak * Fixing the debug tests Signed-off-by: Przemek Tredak Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Przemek Tredak Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 2 +- tests/pytorch/debug/test_sanity.py | 25 ++- .../common/transformer_engine.cpp | 4 +- .../dot_product_attention.py | 21 +- .../pytorch/attention/multi_head_attention.py | 5 +- transformer_engine/pytorch/distributed.py | 8 +- transformer_engine/pytorch/module/base.py | 188 ++++++++++-------- .../pytorch/module/grouped_linear.py | 9 +- .../pytorch/module/layernorm_linear.py | 13 +- .../pytorch/module/layernorm_mlp.py | 13 +- transformer_engine/pytorch/module/linear.py | 14 +- 
transformer_engine/pytorch/transformer.py | 11 +- 12 files changed, 170 insertions(+), 143 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 9111d3511..6fe0ffdae 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2790,7 +2790,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, num_gemms=3) as inp: + with self.prepare_forward_ctx(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py index aee5474e7..2bc4b3559 100644 --- a/tests/pytorch/debug/test_sanity.py +++ b/tests/pytorch/debug/test_sanity.py @@ -30,10 +30,17 @@ stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] start_step : 0 end_step: 1 +""", + "log_fp8": """log_fp8: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: LogFp8TensorStats: enabled: True tensors: [activation, gradient, weight] - stats: [underflows, overflows] + stats: [underflows%] start_step : 0 end_step: 1 """, @@ -46,22 +53,26 @@ FakeQuant: enabled: True gemms: [fprop, dgrad, wgrad] + tensors: [activation, weight, gradient] quant_format: FP8E5M2 """, } +# Configs that require FP8 to be enabled +fp8_required_configs = {"log_fp8"} + def _get_model(model_key): if model_key == "linear": - return te.Linear(D, D) + return te.Linear(D, D, name="layer") if model_key == "layernorm_linear": - return te.LayerNormLinear(D, D) + return te.LayerNormLinear(D, D, name="layer") if model_key == "layernorm_mlp": - return te.LayerNormMLP(D, D, D) + return te.LayerNormMLP(D, D, D, name="layer") if model_key == "mha_attention": - return te.MultiheadAttention(D, H) + return te.MultiheadAttention(D, H, name="layer") if model_key == "transformer_layer": - return te.TransformerLayer(D, D, H) + return te.TransformerLayer(D, D, H, name="layer") def _run_forward_backward(model, fp8): @@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir): def test_sanity_debug(model_key, fp8, config_key, feature_dirs): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if not fp8 and config_key in fp8_required_configs: + pytest.skip(f"Config '{config_key}' requires FP8") _run_test(model_key, fp8, configs[config_key], feature_dirs) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 6880dd560..06971443d 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -454,9 +454,9 @@ class TensorAllocator { } void Free(NVTETensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid tensor."); free_list.push_back(index); // Clean up @@ -564,9 +564,9 @@ class GroupedTensorAllocator { } void Free(NVTEGroupedTensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor."); free_list.push_back(index); // Clean up diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 6e5a12a10..51ffbc2e4 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -676,9 +676,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # assume attention uses the same fp8_group as GEMMs fp8_group = FP8GlobalStateManager.get_fp8_group() - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters()) + self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled()) + self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration()) fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration if self.fp8_parameters or fp8_enabled: @@ -703,7 +703,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return if self.fp8_parameters and not self.fp8_initialized: @@ -721,7 +721,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Allocate scales and amaxes self.init_fp8_meta_tensors(fp8_recipes) - self.fp8_initialized = True + self.fast_setattr("fp8_initialized", True) self.fp8_meta["recipe"] = fp8_recipe_dpa if fp8_recipe != fp8_recipe_dpa: @@ -1000,7 +1000,7 @@ def forward( cases. It is ignored for other backends and when context parallelism is enabled. """ - with self.prepare_forward( + with self.prepare_forward_ctx( query_layer, num_gemms=3, allow_non_contiguous=True, @@ -1145,10 +1145,11 @@ def forward( if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" - self.attention_type = "cross" - self.flash_attention.attention_type = self.attention_type - self.fused_attention.attention_type = self.attention_type - self.unfused_attention.attention_type = self.attention_type + if self.attention_type != "cross": + self.fast_setattr("attention_type", "cross") + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type query_layer, key_layer, value_layer = [ x.contiguous() if not x.is_contiguous() else x diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f875fd1e0..d813e7c8f 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,7 +8,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -335,6 +334,7 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name + TransformerEngineBaseModule._validate_name(self) common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -739,9 +739,6 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" 
- if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) - # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 004a04ab4..f269e21b8 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -729,8 +729,8 @@ def checkpoint( if isinstance(function, TransformerEngineBaseModule): # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need # to scatter/gather activations that we will recompute anyway. - setattr(function, "fsdp_wrapped", False) - setattr(function, "fsdp_group", None) + function.fast_setattr("fsdp_wrapped", False) + function.fast_setattr("fsdp_group", None) # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments # and execute TE's own checkpointing @@ -2022,7 +2022,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ) root_state = _get_module_fsdp_state(fsdp_root) assert root_state is not None, "Root module does not have a valid _FSDPState." - setattr(fsdp_root.module, "fsdp_group", root_state.process_group) + fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group) # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root) @@ -2033,7 +2033,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " "Please initialize your model without the te.quantized_model_init(...) context." ) - setattr(fsdp_module.module, "fsdp_group", state.process_group) + fsdp_module.module.fast_setattr("fsdp_group", state.process_group) class FullyShardedDataParallel(FSDP): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 875d245a8..841cdf04c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,9 +10,8 @@ import warnings from enum import Enum from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from contextlib import contextmanager -import logging from types import MethodType import torch @@ -50,6 +49,8 @@ is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype, get_nvtx_range_context, + nvtx_range_push, + nvtx_range_pop, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -605,10 +606,10 @@ def fill_userbuffers_buffer_for_all_gather( class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" - def __init__(self) -> None: + def __init__(self, name: Optional[str] = None) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." 
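# The optional `name` argument is now taken at construction time so that debug-tool name validation (via `_validate_name`) runs once in __init__ rather than on every forward call.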
- self.name = None + self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False self.fp8 = False @@ -633,26 +634,22 @@ def __init__(self) -> None: if not TEDebugState.debug_enabled: TEDebugState.initialize() + self._validate_name() - # Names of attributes that can be set quickly (see __setattr__ - # method) - _fast_setattr_names: Set[str] = { - "activation_dtype", - "fp8", - "fp8_initialized", - "fp8_calibration", - "fp8_parameters", - } + def fast_setattr(self, name: str, value: Any) -> None: + """ + Fast version of the Module's set attribute function. + Should be used for regular attributes, but not properties nor parameters/buffers. + """ + self.__dict__[name] = value - def __setattr__(self, name: str, value: Any) -> None: - if name in TransformerEngineBaseModule._fast_setattr_names: - # torch.nn.Module has a custom __setattr__ that handles - # modules, parameters, and buffers. This is unnecessary - # overhead when setting plain attrs. - self.__dict__[name] = value - else: - # Default case - super().__setattr__(name, value) + def module_setattr(self, name: str, value: Any) -> None: + """ + Regular version of the Module's set attribute function. + Should be used only when the fast version cannot be used - for the properties, + parameters and buffers. + """ + super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ @@ -773,7 +770,7 @@ def init_fp8_meta_tensors(self, recipe: Recipe) -> None: self.set_meta_tensor(True, recipe) self.set_meta_tensor(False, recipe) - self.fp8_meta_tensors_initialized = True + self.fast_setattr("fp8_meta_tensors_initialized", True) def get_fp8_meta_tensors(self) -> None: """Get scales and amaxes.""" @@ -930,7 +927,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch_get_autocast_gpu_dtype() + self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype()) return # All checks after this have already been performed once, thus skip @@ -945,7 +942,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: "Data types for parameters must match when outside of autocasted region. " f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" ) - self.activation_dtype = dtype + self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -957,8 +954,8 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N tp_group : ProcessGroup, default = None tensor parallel process group. """ - self.tp_group = tp_group - self.tp_group_initialized = True + self.fast_setattr("tp_group", tp_group) + self.fast_setattr("tp_group_initialized", True) def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: """returns the FP8 weights.""" @@ -974,48 +971,51 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. 
def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - _original_recipe = self.fp8_meta.get("recipe", None) - - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - fp8_enabled = self.fp8 or self.fp8_calibration - self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - - if self.fp8_parameters or fp8_enabled: - if ( - self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] - ): + meta = self.fp8_meta + + fp8 = FP8GlobalStateManager.is_fp8_enabled() + fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", fp8_parameters) + self.fast_setattr("fp8", fp8) + self.fast_setattr("fp8_calibration", fp8_calibration) + fp8_enabled = fp8 or fp8_calibration + meta["fp8_checkpoint"] = fp8_enabled + + _original_recipe = None + + if fp8_parameters or fp8_enabled: + _original_recipe = meta.get("recipe", None) + if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe: # FP8 init has already been run and recipe is the same, don't do anything. return - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + if fp8_parameters and not self.fp8_initialized: + meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(meta["recipe"]) if fp8_enabled: # Set FP8 and other FP8 metadata - self.fp8_meta["num_gemms"] = num_gemms - self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + meta["num_gemms"] = num_gemms + meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - if hasattr(self.fp8_meta["recipe"], "fp8_format"): - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(meta["recipe"], "fp8_format"): + meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd + meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) - self.fp8_initialized = True + self.init_fp8_meta_tensors(meta["recipe"]) + self.fast_setattr("fp8_initialized", True) - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - _current_recipe = self.fp8_meta["recipe"] + _current_recipe = meta["recipe"] if _original_recipe is not None and not ( issubclass(_current_recipe.__class__, _original_recipe.__class__) or issubclass(_original_recipe.__class__, _current_recipe.__class__) @@ -1028,22 +1028,18 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Clear cached workspaces as they were created with the old recipe/quantizer type self._fp8_workspaces.clear() - @contextmanager def prepare_forward( self, inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, - ) -> Generator[torch.Tensor, None, None]: - 
"""Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True + ) -> torch.Tensor: + """Checks and prepares for FWD execution.""" + self.fast_setattr( + "allow_different_data_and_param_types", allow_different_data_and_param_types + ) + self.fast_setattr("forwarded_at_least_once", True) # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): @@ -1074,13 +1070,37 @@ def prepare_forward( if self.training and is_fp8_activation_recompute_enabled(): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - with get_nvtx_range_context(self.__class__.__name__ + " forward"): - if not allow_non_contiguous and not inp.is_contiguous(): - inp = inp.contiguous() - yield inp + nvtx_range_push(self.__class__.__name__ + " forward") + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + return inp + def end_forward(self): + """ + Required to be called at the end of the forward function to properly handle + DelayedScaling metadata handling and the NVTX ranges. + """ + delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) + nvtx_range_pop() + + @contextmanager + def prepare_forward_ctx( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, + ) -> Generator[torch.Tensor, None, None]: + """Checks and prepares for FWD execution.""" + inp = self.prepare_forward( + inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types + ) + try: + yield inp + finally: + self.end_forward() def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled @@ -1315,9 +1335,9 @@ def clear(self): # Update the parameter based on its type if not is_dtensor: - setattr(self, name, param) + self.module_setattr(name, param) else: - setattr(self, name, dtensor_param) + self.module_setattr(name, dtensor_param) @abstractmethod def forward(self): @@ -1516,7 +1536,6 @@ def is_debug_iter(self) -> bool: debug = TEDebugState.debug_enabled if not debug: return False - self._validate_name() # If layer is run first time in new iteration, # we need to check if the debug should be enabled for this layer - @@ -1530,14 +1549,14 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() - self.debug_enabled_in_this_iteration = debug + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) + self.fast_setattr("debug_enabled_in_this_iteration", debug) else: # If this is the same iteration as previous invocation of the module, # we use the debug value from the first invocation in the iteration. 
debug = self.debug_enabled_in_this_iteration - self.debug_last_iteration = TEDebugState.get_iteration() + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) if self.wgrad_store is not None: if debug and self.wgrad_store.delay_wgrad_compute(): @@ -1553,7 +1572,9 @@ def no_debug_features_active(self, quantizers): # Sometimes features inform that they will not be enabled for particular layer # for multiple next iterations. - self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) + self.fast_setattr( + "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers) + ) if not run_current: return True @@ -1565,22 +1586,13 @@ def no_debug_features_active(self, quantizers): def _validate_name(self): """ Validate name passed to the module. - This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. - If no name is assigned, it creates a default name with layer count as the variable. + It creates a default name with layer count as the variable + which may be changed by the user of the module. """ if self.name is not None: return - assert TEDebugState.debug_enabled - import nvdlfw_inspect.api as debug_api - - if self.name is None: - debug_api.log_message( - "Names are not provided to debug modules. ", - "Creating and using generic names. Pass names to debug modules for better" - " insight. ", - level=logging.WARNING, - ) - self.name = f"Layer_{TEDebugState.get_layer_count()}" + + self.name = f"Layer_{TEDebugState.get_layer_count()}" def _check_weight_tensor_recipe_correspondence(self) -> None: """ diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e6e69b3e4..c9ceb714e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -614,7 +614,7 @@ def __init__( save_original_input: bool = False, name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms @@ -633,7 +633,6 @@ def __init__( ), "GroupedLinear doesn't support Userbuffer overlap." 
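# --- Illustrative sketch only (not part of the patch). It shows how the default naming
# --- in _validate_name() composes with the explicit submodule names TransformerLayer
# --- wires up (name + ".self_attention", etc.). The module-level counter is a
# --- hypothetical stand-in for TEDebugState.get_layer_count().
_layer_count = 0

def _validate_name(name):
    global _layer_count
    if name is not None:
        return name
    _layer_count += 1
    return f"Layer_{_layer_count}"

layer_name = _validate_name(None)            # e.g. "Layer_1" when no name was passed
attn_name = layer_name + ".self_attention"   # mirrors the TransformerLayer wiring
print(layer_name, attn_name)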
self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute) @@ -789,7 +788,8 @@ def forward( is_grad_enabled = torch.is_grad_enabled() - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + inp = self.prepare_forward(inp, num_gemms=self.num_gemms) + try: weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] @@ -844,6 +844,9 @@ def forward( ) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + finally: + self.end_forward() + if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ca30ef956..702916696 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1158,9 +1158,9 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, - name: str = None, + name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1179,7 +1179,6 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) - self.name = name if tp_group is None: self.tp_size = tp_size @@ -1508,10 +1507,11 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + inp = self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer - ) as inp: + ) + try: # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() @@ -1590,6 +1590,9 @@ def forward( non_tensor_args, ) + finally: + self.end_forward() + if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 35e452213..bec674451 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1787,7 +1787,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, - name: str = None, + name: Optional[str] = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1796,7 +1796,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, checkpoint: bool = False, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.fuse_wgrad_accumulation = fuse_wgrad_accumulation @@ -1827,7 +1827,6 @@ def __init__( for use_fp8 in [False, True] ) ) - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -2047,8 +2046,9 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + inp = self.prepare_forward(inp, num_gemms=2) + try: quantizers = ( self._get_quantizers(fp8_output, is_grad_enabled) if not debug @@ -2087,7 +2087,7 @@ def forward( # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if self.bias_gelu_nvfusion 
and not use_reentrant_activation_recompute(): - self.bias_gelu_nvfusion = False + self.fast_setattr("bias_gelu_nvfusion", False) if is_grad_enabled: fwd_fn = _LayerNormMLP.apply @@ -2157,6 +2157,9 @@ def forward( non_tensor_args, ) + finally: + self.end_forward() + if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 38104604d..23ad8cacb 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -428,8 +428,8 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - if cpu_offloading: mark_not_offload(weight, weightmat, bias) + # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -1098,7 +1098,7 @@ def __init__( save_original_input: bool = False, name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1111,7 +1111,6 @@ def __init__( self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1395,11 +1394,8 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( - inp, - allow_non_contiguous=isinstance(inp, QuantizedTensor), - ) as inp: - + inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) + try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() quantizers = ( @@ -1470,6 +1466,8 @@ def forward( bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, non_tensor_args, ) + finally: + self.end_forward() if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 9b9ccc518..7c3125a16 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -12,7 +12,6 @@ from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.jit import ( @@ -398,6 +397,7 @@ def __init__( self.softmax_type = softmax_type self.name = name + TransformerEngineBaseModule._validate_name(self) attention_args = ( hidden_size, @@ -446,7 +446,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".self_attention" if name is not None else None, + name=self.name + ".self_attention" if self.name is not None else None, ) if layer_type == "decoder": @@ -463,7 +463,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".inter_attention" if name is not None else None, + name=self.name + ".inter_attention" if self.name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -499,7 +499,7 @@ def __init__( activation_params=activation_params, normalization=normalization, device=device, 
- name=name + ".layernorm_mlp" if name is not None else None, + name=self.name + ".layernorm_mlp" if self.name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -768,9 +768,6 @@ def forward( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ), "Encoder-decoder attention mask must be boolean tensor(s)" - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) - # For AMP if torch.is_autocast_enabled(): hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype()) From bf4af7e8bdc9c5ad6740f793d55e18ab24e09e55 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:06:10 -0800 Subject: [PATCH 06/22] [JAX] Fix cb.CUDAOptions usage for Triton 3.6.0 (#2610) * Fix cb.CUDAOptions usage for Triton 3.6.0 Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> * Update utils.py Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> * Update utils.py Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --------- Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/jax/triton_extensions/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 6ea4092cb..2627a0892 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -36,6 +36,8 @@ from typing import Any, Callable, Mapping import zlib +from packaging import version + from jax import core import jax import jax.numpy as jnp @@ -274,13 +276,16 @@ def compile_triton( return _TRITON_KERNEL_CACHE[cache_key] # Compile kernel + cuda_option_kwargs = {} + if version.parse(_TRITON_VERSION) < version.parse("3.6.0"): + cuda_option_kwargs["cluster_dims"] = (1, 1, 1) options = cb.CUDAOptions( num_warps=num_warps, num_stages=num_stages, num_ctas=num_ctas, - cluster_dims=(1, 1, 1), debug=False, enable_fp_fusion=enable_fp_fusion, + **cuda_option_kwargs, ) # Mark constants as constexpr in signature @@ -303,8 +308,6 @@ def compile_triton( # Create kernel object for JAX # From jax/jaxlib/gpu/triton_kernels.cc: - from packaging import version - if version.parse(jax.__version__) >= version.parse("0.8.2"): kernel = gpu_triton.TritonKernel( compiled.name, # arg0: kernel_name (str) From f49f515471f80bd442e42693548403e999a4cd81 Mon Sep 17 00:00:00 2001 From: Teddy Do Date: Thu, 22 Jan 2026 17:14:33 -0800 Subject: [PATCH 07/22] Fix bugs in permutation custom partitioning (#2617) * Use correct block size for workspace in row id map creation, also shard workspace correctly based on 2nd dim of routing_map/row_id map Signed-off-by: DoubleCheeseCheetos * reduce size of largest test case on single_GPU scenario to fit on L40 and A100 in CI line up Signed-off-by: tdophung --------- Signed-off-by: DoubleCheeseCheetos Signed-off-by: tdophung Co-authored-by: DoubleCheeseCheetos --- tests/jax/test_permutation.py | 4 +-- .../jax/triton_extensions/permutation.py | 29 ++++++++++++------- 2 
files changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 5bb59c6ed..138a81724 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -23,7 +23,7 @@ (128, 5, 128, 3), (1024, 8, 128, 8), (4096, 32, 1280, 2), - (4096, 256, 4096, 6), + (4096, 64, 4096, 6), ] DISPATCH_COMBINE_CASES = { "L0": ALL_DISPATCH_COMBINE_CASES[0:2], @@ -44,7 +44,7 @@ (128, 5, 128, 3, 8), (1024, 8, 128, 8, 16), (4096, 32, 1280, 2, 128), - (4096, 256, 4096, 6, 16), + (4096, 64, 4096, 6, 16), ] DISPATCH_COMBINE_PADDING_CASES = { "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2], diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index bd8bd8ff1..0c80f9f18 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -65,8 +65,6 @@ class RowIdMapPass1Primitive(BasePrimitive): @staticmethod def abstract(routing_map_aval, *, num_tokens, num_experts, block_size): """Shape/dtype inference for pass 1.""" - del block_size # Only affects grid, not output shape - assert routing_map_aval.shape == ( num_tokens, num_experts, @@ -75,7 +73,7 @@ def abstract(routing_map_aval, *, num_tokens, num_experts, block_size): row_id_map_shape = (num_tokens, num_experts * 2 + 1) workspace_shape = ( num_experts, - triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE), + triton.cdiv(num_tokens, block_size), ) return ( @@ -134,9 +132,10 @@ def infer_sharding_from_operands( desc="RowIdMapPass1.row_id_map_sharding", ) # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, routing_map_spec[0]), desc="RowIdMapPass1.workspace_sharding", ) return [row_id_map_sharding, workspace_sharding] @@ -156,9 +155,11 @@ def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos PartitionSpec(routing_map_spec[0], None), desc="RowIdMapPass1.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, routing_map_spec[0]), desc="RowIdMapPass1.workspace_sharding", ) out_shardings = [row_id_map_sharding, workspace_sharding] @@ -186,7 +187,8 @@ def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, # Note: row_id_cols != experts since it's num_experts * 2 + 1 row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols") # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) - workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks") + # Second dim depends on num_tokens, so use same factor to ensure same sharding + workspace_spec = (f"{prefix}_experts", f"{prefix}_tokens") return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec)) @@ -208,10 +210,9 @@ class RowIdMapPass2Primitive(BasePrimitive): def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size): """Shape/dtype inference for pass 2 (in-place operation).""" del row_id_map_aval, workspace_aval - del block_size row_id_map_shape = (num_tokens, num_experts * 2 + 1) - workspace_shape = (num_experts, triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE)) + workspace_shape = (num_experts, triton.cdiv(num_tokens, 
block_size)) return ( jax.core.ShapedArray(row_id_map_shape, jnp.int32), @@ -270,9 +271,11 @@ def infer_sharding_from_operands( PartitionSpec(*row_id_map_spec), desc="RowIdMapPass2.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, row_id_map_spec[0]), desc="RowIdMapPass2.workspace_sharding", ) return [row_id_map_sharding, workspace_sharding] @@ -292,9 +295,11 @@ def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos PartitionSpec(*row_id_map_spec), desc="RowIdMapPass2.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, row_id_map_spec[0]), desc="RowIdMapPass2.workspace_sharding", ) out_shardings = [row_id_map_sharding, workspace_sharding] @@ -317,7 +322,9 @@ def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, del num_tokens, num_experts, block_size, mesh, value_types, result_types prefix = "RowIdMapPass2" row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols") - workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks") + # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so use same factor to ensure same sharding + workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_tokens") return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec)) From d9b7fc5770a88af06e2e9c2bd97b550614c3a69f Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Fri, 23 Jan 2026 06:37:48 +0100 Subject: [PATCH 08/22] [Common] Disabled the tuned NVFP4 kernels (#2615) * Disabled the tuned NVFP4 kernels Signed-off-by: Oleg Goncharov * Disabled fast math in cpp tests Signed-off-by: Oleg Goncharov --------- Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 7 +------ .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 8 ++++---- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index c4df8759f..d8d495d61 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -677,11 +677,6 @@ std::vector Activation_types = { ActivationType::Identity }; -std::vector use_fast_nvfp4_scaling_vec = { - false, - true -}; - } // namespace class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam @@ -743,7 +738,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), ::testing::Values(DType::kBFloat16), - ::testing::ValuesIn(use_fast_nvfp4_scaling_vec)), + ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); const auto& shape = std::get<1>(info.param); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 99776db28..61c6ba9ce 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1168,10 +1168,10 @@ 
void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); - if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { - quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); - return; - } + // if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { + // quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); + // return; + // } constexpr bool COMPUTE_ACTIVATIONS = false; using ParamOP = Empty; From 07f7750384fbdea7d137d8b317ccb88c255c9224 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 22 Jan 2026 12:00:04 -0800 Subject: [PATCH 09/22] [PyT] Update THD sink attention logic for cudnn >=9.18.0 (#2568) * Update THD sink attention logic for newer cudnn versions THD Sink attention is supported in 9.18.0 Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update thd sink attention logic for cp>1 Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add unit test for thd + sink attention Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address comments Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * do not skip thd cp sink attention test Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable deterministic mode for sink attention Signed-off-by: Chen Cui --------- Signed-off-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 9 ++++++ .../attention/test_attention_with_cp.py | 9 ++++-- .../dot_product_attention/context_parallel.py | 18 ++++++----- .../attention/dot_product_attention/utils.py | 31 ++++++++++--------- 4 files changed, 42 insertions(+), 25 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 6fe0ffdae..65ca74c48 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -429,6 +429,15 @@ def test_dpa_softmax(dtype, model_configs, model): ) +@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax_thd(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False) + + model_configs_mla = { # test: ModelConfig(b, sq, hq, dqk) "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 9480b8de7..06ed6e572 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -283,9 +283,14 @@ def test_cp_with_fused_attention( pytest.skip( "CP implementation only supports cp_comm_type=a2a for non-vanilla 
softmax types!" ) - if config.softmax_type != "vanilla" and qkv_format == "thd": + if ( + get_cudnn_version() < (9, 18, 0) + and config.softmax_type != "vanilla" + and qkv_format == "thd" + ): pytest.skip( - "CP implementation does not support qkv_format=thd for non-vanilla softmax types!" + "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" + " non-vanilla softmax types!" ) dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 75b360e48..a5931188d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4026,28 +4026,30 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" + ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "Context parallelism does not support MLA with {cp_comm_type=}!" + ], f"Context parallelism does not support MLA with {cp_comm_type=}!" if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: assert ( softmax_type == "vanilla" - ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + ), f"Context parallelism does not support {softmax_type=} with FP8 attention!" assert ( softmax_type == "vanilla" or use_fused_attention - ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!" assert ( softmax_type == "vanilla" or cp_comm_type == "a2a" - ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" - assert ( - softmax_type == "vanilla" or qkv_format != "thd" - ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" + ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" + if get_cudnn_version() < (9, 18, 0): + assert softmax_type == "vanilla" or qkv_format != "thd", ( + f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with" + " qkv_format = 'thd'!" 
+ ) args = [ is_training, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index cb74a15e7..fcac740cc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -716,22 +716,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_unfused_attention = False if qkv_format == "thd": - logger.debug( - "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type - ) - use_fused_attention = False - logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", - softmax_type, - ) - use_unfused_attention = False + if cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", + softmax_type, + ) + use_fused_attention = False if context_parallel: - logger.debug( - "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" - " = %s", - softmax_type, - ) - use_unfused_attention = False if cp_comm_type != "a2a": logger.debug( "Disabling FusedAttention for context parallelism with softmax_type = %s and" @@ -1049,6 +1041,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_fused_attention and deterministic: + if softmax_type != "vanilla": + logger.debug( + "Disabling FusedAttention for determinism reasons with softmax_type = %s. " + "Sink attention (off-by-one and learnable softmax) requires " + "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", + softmax_type, + ) + use_fused_attention = False + fused_attention_backend = None if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: logger.debug("Disabling FusedAttention for determinism reasons with FP8") use_fused_attention = False From fdc0168a6a65bc55ba3add36d49f793247620702 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 22 Jan 2026 12:07:20 -0800 Subject: [PATCH 10/22] Add support for SWA (left, right) with FusedAttention (#2477) * SWA (left, right) with FusedAttention changes cherry-picked from https://github.com/NVIDIA/TransformerEngine/pull/1369 Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test_kv_cache failures Signed-off-by: Sudhakar Singh * remove unnecessary comments Signed-off-by: Sudhakar Singh * fix some more filter issues, address feedback Signed-off-by: Sudhakar Singh * fix for local test case failures - `bottom_right_diagonal` should be calculated in `fused_attn_fwd` call as well Signed-off-by: Sudhakar Singh * make conditions more accurate Signed-off-by: Sudhakar Singh * add cp tests to test swa (left, right) Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove dead code and make conditions better Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feedback form Charlene Signed-off-by: Sudhakar Singh * small er Signed-off-by: Sudhakar Singh * plumb `bottom_right_diagonal` through jax Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci * plumb `bottom_right_diagonal` through jax Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add missing fields Signed-off-by: Sudhakar Singh * use proper mask type in CP Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sudhakar Singh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 7 +- .../attention/test_attention_with_cp.py | 15 ++- tests/pytorch/utils.py | 14 +-- .../common/fused_attn/fused_attn.cpp | 97 ++++++++++--------- .../fused_attn_f16_arbitrary_seqlen.cu | 75 ++++++++------ .../fused_attn_f16_arbitrary_seqlen.h | 19 ++-- .../common/fused_attn/fused_attn_fp8.cu | 2 + transformer_engine/common/fused_attn/utils.h | 12 ++- .../include/transformer_engine/fused_attn.h | 63 ++++++------ .../jax/cpp_extensions/attention.py | 22 ++++- transformer_engine/jax/csrc/extensions.h | 4 +- .../jax/csrc/extensions/attention.cpp | 62 ++++++------ .../dot_product_attention/backends.py | 21 +++- .../dot_product_attention.py | 43 ++++++-- .../attention/dot_product_attention/utils.py | 82 +++++++++------- .../pytorch/attention/multi_head_attention.py | 26 +++++ .../pytorch/cpp_extensions/fused_attn.py | 22 +++++ transformer_engine/pytorch/csrc/extensions.h | 15 +-- .../pytorch/csrc/extensions/attention.cpp | 51 +++++----- transformer_engine/pytorch/transformer.py | 55 ++++++++++- 20 files changed, 474 insertions(+), 233 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 65ca74c48..bd0ac4197 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -153,6 +153,7 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] if qkv_format == "thd" and "padding" not in config.attn_mask_type: @@ -171,6 +172,7 @@ def test_dot_product_attention( deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: is_training = False available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -701,9 +703,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model): @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model", model_configs_swa.keys()) -def test_dpa_sliding_window(dtype, model_configs, model): +@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"]) +def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with sliding window attention""" - test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False) + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False) model_configs_alibi_slopes = { diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 06ed6e572..836598087 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ 
b/tests/pytorch/attention/test_attention_with_cp.py @@ -147,7 +147,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA - "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( @@ -163,7 +163,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" ), # GQA "cp_2_4": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA @@ -187,7 +187,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = [ + "cp_1_0", + "cp_1_1", + "cp_1_4", + "cp_2_0", + "cp_2_2", + "cp_2_4", + "cp_3_2", + "cp_4_2", + ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index ca5fbc997..b6a84a8e2 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -353,11 +353,11 @@ def test(): backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() - with logging_context(highest_level=AttentionLogging._log_level): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 415bfae06..4f8367aac 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (window_size_right == -1 || window_size_right == 0)) || // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - 
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((window_size_left == -1 && window_size_right == -1 && + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && @@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((window_size_left >= 0 || window_size_left == -1) && + (window_size_right >= 0 || window_size_right == -1) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && // TODO(cyang): fix bug for BRCM + cross-attention on sm100 (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) || cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && @@ -515,16 +519,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV // DEPRECATED: This API is deprecated. // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -598,10 +600,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, fused_attn_arbitrary_seqlen_fwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - 
input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " @@ -639,8 +641,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -736,10 +738,11 @@ void nvte_fused_attn_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, - &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, - &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -790,7 +793,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -902,10 +906,10 @@ void nvte_fused_attn_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + 
window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. " @@ -945,8 +949,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1052,11 +1056,11 @@ void nvte_fused_attn_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -1106,8 +1110,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1195,10 +1199,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - 
input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " @@ -1228,8 +1232,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1302,8 +1307,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index d3746fc04..53023361e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, - void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, + void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t 
*workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); @@ -129,6 +130,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, true, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -254,9 +256,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } sdpa_options.set_alibi_mask(is_alibi); @@ -542,13 +551,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, - void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, - void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, + void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -563,6 +573,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); @@ -621,6 +632,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -781,9 +793,17 @@ void 
fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } if (cudnn_runtime_version >= 90000) { sdpa_backward_options.set_deterministic_algorithm(deterministic); @@ -1044,8 +1064,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -1180,11 +1200,11 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, + devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1206,13 +1226,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - 
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1273,8 +1294,8 @@ void fused_attn_arbitrary_seqlen_bwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index c34eae4e6..4dd7f3d1d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, 
const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3630041cc..f886ec77f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1( 0, 0, true, + true, qkv_tensor_type, o_tensor_type, cudnn_frontend::DataType_t::NOT_SET, @@ -2035,6 +2036,7 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, + true, false, qkv_tensor_type, o_tensor_type, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 7d23bb5c5..fdfc4abe8 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -110,6 +110,7 @@ struct FADescriptor_v1 { NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; + bool bottom_right_diagonal; bool deterministic; cudnn_frontend::DataType_t qkv_tensor_type; cudnn_frontend::DataType_t o_tensor_type; @@ -121,15 +122,16 @@ struct FADescriptor_v1 { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, + generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); + rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, + rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, + rhs.do_tensor_type, rhs.dqkv_tensor_type, 
rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 0fabb81ae..cddd3d750 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -270,22 +270,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ [[deprecated( "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " "Q, K, V tensors instead.")]] -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -333,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. 
@@ -347,8 +347,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. * @@ -410,6 +410,7 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -425,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -479,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -495,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -560,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -629,6 +636,7 @@ void nvte_fused_attn_fwd( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -644,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
* diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index ee10115aa..e5d75e150 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -70,6 +70,7 @@ "is_training", "max_segments_per_seq", "window_size", + "bottom_right_diagonal", "context_parallel_load_balanced", "cp_axis", "cp_striped_window_size", @@ -91,6 +92,7 @@ class _FusedAttnConfig: is_training: bool max_segments_per_seq: int window_size: Tuple[int, int] + bottom_right_diagonal: bool context_parallel_load_balanced: bool cp_axis: str cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA @@ -371,6 +373,11 @@ def abstract( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + bottom_right_diagonal = config.attn_mask_type in [ + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to # prepare for the active fused-attn backend input_batch = reduce(operator.mul, batch_shape) @@ -395,6 +402,7 @@ def abstract( config.max_segments_per_seq, config.window_size[0], config.window_size[1], + bottom_right_diagonal, ) wkspace_aval = q_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -503,6 +511,7 @@ def lowering( deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=window_size_left, window_size_right=window_size_right, + bottom_right_diagonal=config.bottom_right_diagonal, softmax_type=int(config.softmax_type.value), ) @@ -813,6 +822,7 @@ def abstract( config.max_segments_per_seq, config.window_size[0], config.window_size[1], + config.bottom_right_diagonal, ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) @@ -948,6 +958,7 @@ def lowering( deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=window_size_left, window_size_right=window_size_right, + bottom_right_diagonal=config.bottom_right_diagonal, softmax_type=int(config.softmax_type.value), ) @@ -1357,9 +1368,10 @@ def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + adjusted_mask = self.get_adjusted_mask() return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, - attn_mask_type=self.get_adjusted_mask(), + attn_mask_type=adjusted_mask, softmax_type=self.config.softmax_type, qkv_layout=self.config.qkv_layout, scaling_factor=self.config.scaling_factor, @@ -1367,6 +1379,7 @@ def get_step_config(self) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.config.max_segments_per_seq, window_size=self.config.window_size, + bottom_right_diagonal=adjusted_mask.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -1375,9 +1388,10 @@ def get_step_config(self) -> _FusedAttnConfig: def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention.""" + adjusted_mask = self.get_adjusted_mask() return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, - attn_mask_type=self.get_adjusted_mask(), + attn_mask_type=adjusted_mask, softmax_type=self.config.softmax_type, 
qkv_layout=self.config.qkv_layout, scaling_factor=self.config.scaling_factor, @@ -1385,6 +1399,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size), window_size=self.config.window_size, + bottom_right_diagonal=adjusted_mask.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -2430,6 +2445,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.config.max_segments_per_seq, window_size=self.config.window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -3418,6 +3434,7 @@ def fused_attn_fwd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=(-1, -1) if window_size is None else window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, @@ -3590,6 +3607,7 @@ def fused_attn_bwd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=(-1, -1) if window_size is None else window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 5f9339263..3fd086e25 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -121,7 +121,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool bottom_right_diagonal); pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, @@ -129,7 +129,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, - int64_t window_size_left, int64_t window_size_right); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal); // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 4fe8e728a..92e67ac19 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -144,7 +144,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType 
dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool bottom_right_diagonal) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -192,7 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(), + nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -237,7 +238,7 @@ static void FusedAttnForwardImpl( size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -328,7 +329,7 @@ static void FusedAttnForwardImpl( k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -346,6 +347,7 @@ static void FusedAttnForwardImpl( size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); \ auto window_size_left = get_attr_value(attrs, "window_size_left"); \ auto window_size_right = get_attr_value(attrs, "window_size_right"); \ + bool bottom_right_diagonal = get_attr_value(attrs, "bottom_right_diagonal"); \ float scaling_factor = get_attr_value(attrs, "scaling_factor"); \ float dropout_probability = get_attr_value(attrs, "dropout_probability"); \ NVTE_Bias_Type bias_type = \ @@ -384,7 +386,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype, - is_training, deterministic, window_size_left, window_size_right); + is_training, deterministic, window_size_left, window_size_right, bottom_right_diagonal); return ffi_with_cuda_error_check(); } @@ -415,7 +417,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { 
auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); @@ -467,17 +469,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - nvte_fused_attn_bwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, false, + query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -496,7 +499,7 @@ static void FusedAttnBackwardImpl( size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -593,16 +596,17 @@ static void FusedAttnBackwardImpl( } } - nvte_fused_attn_bwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), - dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dsoftmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, + kv_max_seqlen, scaling_factor, dropout_probability, 
qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_input_tensors); } @@ -631,7 +635,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, - window_size_right); + window_size_right, bottom_right_diagonal); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index c726ed884..ef7fa0dcc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -261,6 +261,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -346,6 +347,11 @@ def forward( attention_mask=attention_mask, window_size=window_size, attention_type=self.attention_type, + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) ) @@ -449,7 +455,11 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) matmul_result = torch.baddbmm( matmul_result, @@ -1110,6 +1120,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, fused_attention_backend, use_FAv2_bwd, @@ -1213,6 +1224,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, softmax_offset, cuda_graph=is_graph_capturing(), @@ -1290,6 +1302,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, softmax_offset, return_max_logit, @@ -1377,6 +1390,7 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type ctx.window_size = window_size + ctx.bottom_right_diagonal = bottom_right_diagonal ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] ) @@ -1527,6 +1541,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1592,6 +1607,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1631,6 +1647,7 @@ def backward(ctx, d_out, *_args): None, None, None, + None, d_softmax_offset, None, None, @@ -1728,6 +1745,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, 
Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1935,6 +1953,7 @@ def forward( attn_mask_type, self.softmax_type, window_size, + bottom_right_diagonal, None, # rng_gen fused_attention_backend, use_FAv2_bwd, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 51ffbc2e4..5a554d86e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in ``forward`` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. attention_type : str, default = "self" type of attention, either ``"self"`` and ``"cross"``. layer_number : int, default = None @@ -324,6 +329,7 @@ def __init__( qkv_format: str = "sbhd", attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -350,6 +356,7 @@ def __init__( attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -811,6 +818,7 @@ def forward( max_seqlen_kv: int = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -963,6 +971,16 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = None + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. + Note: This parameter will be automatically overridden based on the + `attn_mask_type` - it will be forced to `False` for 'causal' and + 'padding_causal' mask types, and forced to `True` for mask types + containing 'bottom_right' (e.g., 'causal_bottom_right', + 'padding_causal_bottom_right'), regardless of the explicitly passed value. 
checkpoint_core_attention : bool, default = False If true, forward activations for attention are recomputed during the backward pass in order to save memory that would @@ -1081,6 +1099,15 @@ def forward( if window_size is None: window_size = self.window_size window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True # checks for qkv_format if qkv_format is None: @@ -1144,6 +1171,8 @@ def forward( assert "padding" in attn_mask_type, "KV caching requires padding mask!" if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" + # since attention mask is changed, set `bottom_right_diagonal` to True + bottom_right_diagonal = True if self.attention_type != "cross": self.fast_setattr("attention_type", "cross") @@ -1257,7 +1286,6 @@ def forward( if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) if core_attention_bias_type == "alibi": assert ( core_attention_bias is None @@ -1266,7 +1294,7 @@ def forward( _alibi_cache["_num_heads"] != query_layer.shape[-2] or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment + or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal or _alibi_cache["_alibi_slopes"] is None ): _alibi_cache["_alibi_slopes_require_update"] = True @@ -1323,6 +1351,7 @@ def forward( head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, core_attention_bias_type=core_attention_bias_type, core_attention_bias_shape=core_attention_bias_shape, @@ -1446,9 +1475,7 @@ def forward( if use_fused_attention: fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias = core_attention_bias - if core_attention_bias_type == "alibi" and ( - alibi_slopes is not None or max_seqlen_q != max_seqlen_kv - ): + if core_attention_bias_type == "alibi" and (alibi_slopes is not None): fu_core_attention_bias_type = "post_scale_bias" _, fu_core_attention_bias = dpa_utils.get_alibi( _alibi_cache, @@ -1457,7 +1484,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, bias_dtype=query_layer.dtype, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=bottom_right_diagonal, ) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -1475,6 +1502,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1505,6 +1533,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, 
core_attention_bias=fu_core_attention_bias, @@ -1539,6 +1568,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -1562,6 +1592,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index fcac740cc..56e6f093d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -200,6 +200,9 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size : Tuple[int, int], default = None Sliding window attention size. + bottom_right_diagonal: bool, default = `None` + Whether to align sliding window and ALiBi diagonal to the bottom right corner + of the softmax matrix. alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. core_attention_bias_type : str, default = no_bias @@ -249,6 +252,7 @@ class AttentionParams: head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None + bottom_right_diagonal: bool = True alibi_slopes_shape: Union[torch.Size, List, None] = None core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" @@ -325,6 +329,7 @@ def get_attention_backend( head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size + bottom_right_diagonal = attention_params.bottom_right_diagonal alibi_slopes_shape = attention_params.alibi_slopes_shape core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_shape = attention_params.core_attention_bias_shape @@ -859,39 +864,43 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # backend | window_size | diagonal alignment # --------------------------------------------------------------------------------- # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; + # FusedAttention | (-1, 0) or (>=0, >=0) | top left, bottom right + # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right # | | converts window_size to an 'arbitrary' mask if window_size is None: window_size = check_set_window_size(attn_mask_type, window_size) - else: - if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention" - " for FP8" - ) - use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0: - logger.debug( - "Disabling FusedAttention as it only supports sliding window attention " - "with (left, 0) and no dropout" - ) - use_fused_attention = False - elif max_seqlen_q > max_seqlen_kv: - logger.debug( - "Disabling FusedAttention as it does not 
support sliding window attention " - "with s_q > s_kv for cross-attention" - ) - use_fused_attention = False - if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if not FlashAttentionUtils.is_installed: - FlashAttentionUtils.version_required = PkgVersion("2.3") - elif not FlashAttentionUtils.v2_3_plus: - logger.debug( - "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" - ) - use_flash_attention_2 = False + if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention for FP8" + ) + use_fused_attention = False + elif attention_dropout != 0.0: + logger.debug( + "Disabling FusedAttention as it only supports sliding window attention " + "without dropout" + ) + use_fused_attention = False + elif max_seqlen_q > max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention " + "with s_q > s_kv for cross-attention" + ) + use_fused_attention = False + if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.3") + elif not FlashAttentionUtils.v2_3_plus: + logger.debug( + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" + ) + use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports sliding window with bottom right" + " diagonal alignment for cross-attention" + ) + use_flash_attention = False # Filter: Attention bias # backend | bias types | ALiBi diagonal alignment @@ -913,6 +922,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt elif not FlashAttentionUtils.v2_4_plus: logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" + " alignment for cross-attention" + ) + use_flash_attention = False if ( core_attention_bias_type not in ["no_bias", "alibi"] @@ -930,13 +945,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and core_attention_bias_type == "alibi" - and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) + and (alibi_slopes_shape is not None) ): fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_requires_grad = False - if alibi_slopes_shape is None: - fu_core_attention_bias_shape = "1hss" - elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: + + if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: fu_core_attention_bias_shape = "1hss" elif ( len(alibi_slopes_shape) == 2 diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d813e7c8f..01c4955d7 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -31,6 +31,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from 
transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb +from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled @@ -92,6 +93,11 @@ class MultiheadAttention(torch.nn.Module): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. num_gqa_groups : int, default = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -247,6 +253,7 @@ def __init__( layer_number: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -285,6 +292,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.layer_number = 1 if layer_number is None else layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -621,6 +629,7 @@ def forward( encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -667,6 +676,11 @@ def forward( aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using ``layer_type="decoder"``. @@ -731,6 +745,17 @@ def forward( if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" 
@@ -1001,6 +1026,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index e226ef32d..101e5b252 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -137,6 +137,7 @@ def fused_attn_fwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, return_max_logit: bool = False, @@ -212,6 +213,9 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen : torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -255,6 +259,12 @@ def fused_attn_fwd( max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None """ + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -306,6 +316,7 @@ def fused_attn_fwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -370,6 +381,7 @@ def fused_attn_bwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -442,6 +454,9 @@ def fused_attn_bwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic : bool, default = False whether to execute the backward pass with deterministic behaviours. cuda_graph : bool, default = False @@ -462,6 +477,12 @@ def fused_attn_bwd( gradient tensor of softmax offset of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. 
""" + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -500,6 +521,7 @@ def fused_attn_bwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 591c89f83..f7cf32eaf 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -87,9 +87,10 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -99,10 +100,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index be645d91b..bf62db8c3 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -100,9 +100,10 @@ std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor 
cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -235,7 +236,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -295,7 +296,7 @@ std::vector fused_attn_fwd( te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], workspace.data(), + softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -310,10 +311,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -532,14 +533,14 @@ std::vector fused_attn_bwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, 
window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -549,14 +550,14 @@ std::vector fused_attn_bwd( // execute kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), - te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), - te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd( + te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), + &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7c3125a16..fdb386919 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -34,7 +34,7 @@ from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -148,11 +148,21 @@ class TransformerLayer(torch.nn.Module): distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = "no_mask" type of attention mask passed into softmax operation for decoder. enc_dec_window_size : Optional[Tuple[int, int]], default = None sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. 
zero_centered_gamma : bool, default = False if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -301,7 +311,9 @@ def __init__( kv_channels: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, enc_dec_attn_mask_type: str = "no_mask", + enc_dec_bottom_right_diagonal: Optional[bool] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, @@ -343,8 +355,10 @@ def __init__( self.self_attn_mask_type = self_attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_window_size = enc_dec_window_size + self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad @@ -606,10 +620,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, encoder_output: Optional[torch.Tensor] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask_type: Optional[str] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, + enc_dec_bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -654,6 +670,11 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in encoder. + bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using :attr:`layer_type` = ``"decoder"``. @@ -670,6 +691,11 @@ def forward( Type of attention mask passed into softmax operation for decoder. enc_dec_window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool] = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. 
is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -736,10 +762,35 @@ def forward( self_attn_mask_type = self.self_attn_mask_type if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size) + if enc_dec_attn_mask_type is None: enc_dec_attn_mask_type = self.enc_dec_attn_mask_type if enc_dec_window_size is None: enc_dec_window_size = self.enc_dec_window_size + enc_dec_window_size = dpa_utils.check_set_window_size( + enc_dec_attn_mask_type, enc_dec_window_size + ) + + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if self_attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or self_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + + if enc_dec_bottom_right_diagonal is None: + enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal + if enc_dec_attn_mask_type in {"causal", "padding_causal"}: + enc_dec_bottom_right_diagonal = False + if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + enc_dec_bottom_right_diagonal = True assert ( self_attn_mask_type in AttnMaskTypes @@ -778,6 +829,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -813,6 +865,7 @@ def forward( attention_mask=enc_dec_attn_mask, attn_mask_type=enc_dec_attn_mask_type, window_size=enc_dec_window_size, + bottom_right_diagonal=enc_dec_bottom_right_diagonal, encoder_output=encoder_output, inference_params=inference_params, is_first_microbatch=is_first_microbatch, From 3da26cd1a422dab7d02462c41218ec1d4132c446 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:10:19 -0800 Subject: [PATCH 11/22] [JAX] Use "nyu-mll/glue" instead of "glue" for encoder datasets to fix 404 error (#2625) * Use "nyu-mll/glue" instead of "glue" for encoder datasets to fix 404 error Signed-off-by: Jeremy Berchtold * rename mnist dataset path Signed-off-by: Jeremy Berchtold * add dataset manifest Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- examples/jax/datasets.txt | 3 +++ examples/jax/encoder/test_model_parallel_encoder.py | 4 ++-- examples/jax/encoder/test_multigpu_encoder.py | 4 ++-- examples/jax/encoder/test_multiprocessing_encoder.py | 4 ++-- examples/jax/encoder/test_single_gpu_encoder.py | 4 ++-- examples/jax/mnist/test_single_gpu_mnist.py | 4 ++-- 6 files changed, 13 insertions(+), 10 deletions(-) create mode 100644 examples/jax/datasets.txt diff --git a/examples/jax/datasets.txt b/examples/jax/datasets.txt new file mode 100644 index 000000000..fd3f5bc41 --- /dev/null +++ b/examples/jax/datasets.txt @@ -0,0 +1,3 @@ +# Datasets used by TE encoder tests. 
Pull these to pre-emptively cache datasets +ylecun/mnist +nyu-mll/glue \ No newline at end of file diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 02937bc39..73b93798a 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -219,11 +219,11 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset("glue", "cola", split="train") + train_ds = load_dataset("nyu-mll/glue", "cola", split="train") train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset("glue", "cola", split="validation") + test_ds = load_dataset("nyu-mll/glue", "cola", split="validation") test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 98184ccd7..22a89cc0a 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -197,11 +197,11 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset("glue", "cola", split="train") + train_ds = load_dataset("nyu-mll/glue", "cola", split="train") train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset("glue", "cola", split="validation") + test_ds = load_dataset("nyu-mll/glue", "cola", split="validation") test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 327540521..0166b60ac 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -307,11 +307,11 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset("glue", "cola", split="train") + train_ds = load_dataset("nyu-mll/glue", "cola", split="train") train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset("glue", "cola", split="validation") + test_ds = load_dataset("nyu-mll/glue", "cola", split="validation") test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 82c7fed38..6d67296bd 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -195,11 +195,11 @@ def get_datasets(max_seq_len): vocab = {} word_id = 0 - train_ds = load_dataset("glue", "cola", split="train") + train_ds = load_dataset("nyu-mll/glue", "cola", split="train") train_ds.set_format(type="np") train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) - test_ds = load_dataset("glue", "cola", split="validation") + test_ds = load_dataset("nyu-mll/glue", "cola", split="validation") test_ds.set_format(type="np") test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) return train_ds, test_ds, word_id diff --git a/examples/jax/mnist/test_single_gpu_mnist.py 
b/examples/jax/mnist/test_single_gpu_mnist.py index 0c76d51c3..ef85f4a7a 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -146,7 +146,7 @@ def eval_model(state, test_ds, batch_size, var_collect): def get_datasets(): """Load MNIST train and test datasets into memory.""" - train_ds = load_dataset("mnist", split="train", trust_remote_code=True) + train_ds = load_dataset("ylecun/mnist", split="train", trust_remote_code=True) train_ds.set_format(type="np") batch_size = train_ds["image"].shape[0] shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C) @@ -154,7 +154,7 @@ def get_datasets(): "image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0, "label": train_ds["label"], } - test_ds = load_dataset("mnist", split="test", trust_remote_code=True) + test_ds = load_dataset("ylecun/mnist", split="test", trust_remote_code=True) test_ds.set_format(type="np") batch_size = test_ds["image"].shape[0] shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C) From cad802fe9a2f5b42dc7b91c2eb5e5142d274a744 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 28 Jan 2026 01:27:17 +0100 Subject: [PATCH 12/22] [PyTorch] ONNX test fix + export for FP8 attention (#2598) * jjit bug fix Signed-off-by: Pawel Gadzinski * fix' Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L1_pytorch_onnx_unittest/test.sh | 3 +- tests/pytorch/test_onnx_export.py | 22 +++++++-- .../dot_product_attention/backends.py | 46 +++++++++++++++++++ .../dot_product_attention.py | 4 +- .../attention/dot_product_attention/utils.py | 4 +- transformer_engine/pytorch/jit.py | 34 ++++++++++---- 6 files changed, 97 insertions(+), 16 deletions(-) diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index b3a520e12..6f9ff54e4 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,5 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available +NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 50cd150c4..9aea3bc27 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation): _test_export_layernorm_mlp(activation=activation) +# Quantization recipes with fp8_dpa=True for attention emulation export test +dpa_quantization_recipes = [None] # None = no quantization +if fp8_available: + 
dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True)) + dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True)) + + +@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes) @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ @@ -730,6 +738,7 @@ def test_export_core_attention( precision: torch.dtype, use_mask: bool, attn_mask_type: str, + fp8_recipe: recipe.Recipe, ): # Set dimensions (these are arbitrary). seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) @@ -749,22 +758,25 @@ def test_export_core_attention( mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) - fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" + fp8_str = "_fp8_dpa" if fp8_recipe is not None else "" + fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx" + + is_fp8 = fp8_recipe is not None model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, - attention_dropout=0.5, qkv_format=qkv_format, attn_mask_type=attn_mask_type, ).to(device="cuda") - do_export(model, inp, fname, input_names=input_names, fp8_recipe=None) - te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None) + do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe) + te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16,): return + atol = 5e-1 if is_fp8 else 1e-2 validate_result( - fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs + fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index ef7fa0dcc..aa6c06395 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -164,6 +164,11 @@ class FP8EmulationFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): # pylint: disable=missing-function-docstring + if is_in_onnx_export_mode(): + return FP8EmulationFunc.onnx_forward( + tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout + ) + if quantizer_name == "QKV_quantizer": query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] @@ -202,6 +207,47 @@ def backward(ctx, grad1, grad2, grad3): tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None + @staticmethod + def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None): + """ + ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations. + """ + # pylint: disable=unused-argument + is_qkv_quantizer = quantizer_name == "QKV_quantizer" + assert isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ), "ONNX FP8 emulation path supports only Float8 quantizers." + + if is_qkv_quantizer: + # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3. 
+ orig_dtype = tensor1.dtype + shapes = [tensor1.shape, tensor2.shape, tensor3.shape] + numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()] + + # Flatten and concatenate + combined = torch.cat( + [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0 + ) + + # Quantize + dequantize combined tensor using quantizer's ONNX methods + combined_fp8 = quantizer.onnx_quantize(combined) + out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype) + + # Split back + out1 = out[: numels[0]].reshape(shapes[0]) + out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1]) + out3 = out[numels[0] + numels[1] :].reshape(shapes[2]) + + return out1, out2, out3 + if quantizer_name in ["S_quantizer", "O_quantizer"]: + # Emulate FP8 on single tensor using quantizer's ONNX methods + orig_dtype = tensor1.dtype + t_fp8 = quantizer.onnx_quantize(tensor1) + out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype) + return out, tensor2, tensor3 + # Pass-through + return tensor1, tensor2, tensor3 + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 5a554d86e..5d830dca3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1552,7 +1552,9 @@ def forward( ) if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() + ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 56e6f093d..0c5a51981 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -479,7 +479,9 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() + ) if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 5884188b7..1b93b8254 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -46,17 +46,35 @@ def wrapper(*args, **kwargs): # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda func: func if torch.__version__ >= "2": import torch._dynamo - if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: ( - f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive) - ) - else: - # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True - no_torch_dynamo = lambda recursive=True: torch._dynamo.disable + def no_torch_dynamo(recursive=True): + """Decorator to disable Torch Dynamo, except during ONNX export.""" + + def 
decorator(f): + # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True + disabled_f = ( + torch._dynamo.disable(f, recursive=recursive) + if torch.__version__ >= "2.1" + else torch._dynamo.disable(f) + ) + + @wraps(f) + def wrapper(*args, **kwargs): + if is_in_onnx_export_mode(): + return f(*args, **kwargs) + return disabled_f(*args, **kwargs) + + return wrapper + + return decorator + +else: + # Fallback for PyTorch < 2.0: no-op decorator + def no_torch_dynamo(recursive=True): # pylint: disable=unused-argument + """No-op decorator for PyTorch < 2.0.""" + return lambda func: func def set_jit_fusion_options() -> None: From 9bb9d22645cf5d137a763fe439bae9f4e2b57457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 28 Jan 2026 01:28:30 +0100 Subject: [PATCH 13/22] [common] Add support for cuBLASLt GEMM for GroupedTensor (#2502) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add FP8 scale support and fix alignment for grouped GEMM - Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM - Fix random padding in tests to ensure 16-byte alignment for all dtypes - Reorder GroupedGemmSetupWorkspace members for natural alignment - Remove debug prints Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Grouped GEMM: code cleanup and NULL C support - Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers - Simplify select_grouped_operand by removing dead code branches - Add GroupedOperandSelection.tensor field to avoid passing tensor separately - Extract set_fp8_scale_pointers and init_matrix_layouts helpers - Add safety check for FP8 on Hopper column-wise fallback - Support NULL C tensor when beta=0 (uses D as placeholder) - Remove unused get_scale_inv() from test - Add use_null_c test parameter and test case - Fix documentation: alpha/beta are single element tensors only Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Grouped GEMM: per-matrix alpha/beta support - Change alpha/beta from single values to per-matrix arrays - Validate alpha/beta have exactly num_tensors elements - Update kernel to index alpha_ptr[idx] and beta_ptr[idx] - Move alpha/beta validation to validate_grouped_gemm_inputs - Update tests to use per-matrix alpha/beta arrays - Update documentation Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix alpha/beta numel - use SimpleTensor::numel() Signed-off-by: Piotr Gadzinski Signed-off-by: Pawel Gadzinski * Refactor: move grouped GEMM to separate file and cleanup API Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fixes 
Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/gemm/config.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changed Signed-off-by: Pawel Gadzinski * suggestions Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactored hopper tensor selection Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Signed-off-by: Piotr Gadzinski Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_grouped_gemm.cu | 308 +++++++++ tests/cpp/test_common.cu | 163 +++++ tests/cpp/test_common.h | 54 ++ transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/gemm/config.cpp | 103 +++ transformer_engine/common/gemm/config.h | 19 + .../common/gemm/cublaslt_gemm.cu | 35 +- .../common/gemm/cublaslt_grouped_gemm.cu | 645 ++++++++++++++++++ .../common/include/transformer_engine/gemm.h | 171 +++++ .../common/util/cuda_runtime.cpp | 8 + transformer_engine/common/util/cuda_runtime.h | 6 + 12 files changed, 1494 insertions(+), 20 deletions(-) create mode 100644 tests/cpp/operator/test_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 26efb3796..08a683949 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -30,6 +30,7 @@ add_executable(test_operator test_causal_softmax.cu test_swizzle.cu test_swap_first_dims.cu + test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu new file mode 100644 index 000000000..35c4375cb --- /dev/null +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -0,0 +1,308 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum class InputCase { + kFP8Current, + kBF16, +}; + +enum class ShapeCase { + kAllSame, + kSameFirst, + kSameLast, + kAllDifferent, +}; + +size_t grouped_setup_workspace_size(const size_t num_tensors) { + const size_t ptr_bytes = num_tensors * sizeof(void*); + const size_t int_bytes = num_tensors * sizeof(int); + // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols) + size_t size = 6 * ptr_bytes + 6 * int_bytes; + const size_t alignment = 256; + size = ((size + alignment - 1) / alignment) * alignment; + return size; +} + +Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { + Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); + fillUniform(&input_fp32); + + Tensor fp8(name, shape, TypeInfo::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING); + + nvte_compute_amax(input_fp32.data(), fp8.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(fp8.data(), config, 0); + nvte_quantize(input_fp32.data(), fp8.data(), 0); + return fp8; +} + +Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { + Tensor t(name, shape, DType::kBFloat16); + const size_t numel = shape[0] * shape[1]; + std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f)); + NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(), + numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice)); + return t; +} + +struct TestParams { + InputCase input_case; + bool transa; + bool transb; + ShapeCase shape_case; + bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) +}; + +// Returns a vector of (M, N, K) tuples for each GEMM in the group. +// M - number of rows in output D +// N - number of columns in output D +// K - reduction dimension shared between A and B +std::vector> make_shapes(ShapeCase scase) { + switch (scase) { + case ShapeCase::kAllSame: + return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; + case ShapeCase::kSameFirst: + // Same M (first dim), varying N and K + return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; + case ShapeCase::kSameLast: + // Same N (last dim), varying M and K + return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; + case ShapeCase::kAllDifferent: + default: + return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; + } +} + +void run_grouped_gemm_case(const TestParams& params) { +#if CUBLAS_VERSION < 130100 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? std::vector{M, K} + : std::vector{K, M}; + const std::vector b_shape = params.transb ? 
std::vector{K, N} + : std::vector{N, K}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); + + // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, // grad + workspace_ptrs.data(), + false, // accumulate + false, // use_split_accumulator + 0, // sm_count + 0); + + GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_group_tensors; + C_tensors.reserve(num_gemms); + D_group_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + if (!params.use_null_c) { + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + } + D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), + std::vector{static_cast(M), static_cast(N)}, + DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); + } + + std::vector C_views, D_views; + for (size_t i = 0; i < num_gemms; ++i) { + if (!params.use_null_c) { + C_views.push_back(&C_tensors[i]); + } + D_views.push_back(&D_group_tensors[i]); + } + + std::optional grouped_C; + if (!params.use_null_c) { + grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + } + GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); + + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), 
cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + nvte_grouped_gemm(grouped_A.get_handle(), + params.transa, + grouped_B.get_handle(), + params.transb, + params.use_null_c ? nullptr : grouped_C->get_handle(), + grouped_D.get_handle(), + alpha_tensor.data(), + beta_tensor.data(), + setup_ws.data(), + cublas_ws.data(), + nullptr, // config (use defaults) + 0); + + for (size_t i = 0; i < num_gemms; ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.get_data()) + offset_bytes, + grouped_D.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_vs_multi", + grouped_split, + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +#endif // CUBLAS_VERSION >= 130100 +} + +class GroupedGemmTest : public ::testing::TestWithParam {}; + +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { + run_grouped_gemm_case(GetParam()); +} + +std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { + constexpr const char* kInputNames[] = {"FP8Current", "BF16"}; + constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; + const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + + "tb" + (info.param.transb ? "T" : "N"); + const std::string null_c = info.param.use_null_c ? 
"_NullC" : ""; + return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + + kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; +} + +// TestParams: {input_case, transa, transb, shape_case, use_null_c} +const std::vector kTestParams = { + // Basic tests + {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, + {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, + {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, + {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, + {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, + // Test NULL C (valid when beta=0) + {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, +}; + +INSTANTIATE_TEST_SUITE_P(OperatorTest, + GroupedGemmTest, + ::testing::ValuesIn(kTestParams), + MakeGroupedGemmTestName); + +} // namespace diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ed961bfe9..af99d9c42 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -1057,4 +1058,166 @@ std::array get_scale_tensor_dims(const size_t rows, return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; } +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode) { + NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); + const NVTEShape shape = tensors[0]->rowwise_shape(); + const DType dtype = tensors[0]->dtype(); + const size_t num_tensors = tensors.size(); + const size_t elem_size = typeToNumBits(dtype) / 8; + GroupedBuffers grouped; + grouped.elem_size = elem_size; + grouped.num_tensors = num_tensors; + grouped.dtype = dtype; + grouped.scaling_mode = scaling_mode; + grouped.tensor_bytes.resize(num_tensors); + grouped.offsets_host.resize(num_tensors, 0); + + std::vector first_dims(num_tensors); + std::vector last_dims(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + const auto s = tensors[i]->rowwise_shape(); + NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors."); + first_dims[i] = static_cast(s.data[0]); + last_dims[i] = static_cast(s.data[1]); + grouped.tensor_bytes[i] = bytes(s, dtype); + } + + const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), + [&](int64_t v) { return v == first_dims[0]; }); + const bool same_last = std::all_of(last_dims.begin(), last_dims.end(), + [&](int64_t v) { return v == last_dims[0]; }); + + std::vector offsets(num_tensors, 0); + auto random_padding = [&]() -> int64_t { + // Random padding ensuring 16-byte alignment regardless of element size + // cuBLAS requires aligned pointers for vectorized loads + static std::mt19937 gen(12345); + std::uniform_int_distribution dist(0, 3); + // Calculate elements needed for 16-byte alignment in bytes, rounded up + const size_t align_elements = + std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size + return dist(gen) * static_cast(align_elements); + }; + + auto numel = [&](size_t idx) -> int64_t { + return first_dims[idx] * last_dims[idx]; + }; + + const bool need_offsets = !same_first || !same_last; + if (need_offsets) { + offsets[0] = 0; + for (size_t i = 1; i < num_tensors; ++i) { + offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); + } + } else { + for (size_t i = 0; i < num_tensors; ++i) { + 
offsets[i] = static_cast(i) * numel(0); + } + } + grouped.offsets_host = offsets; + + int64_t logical_first = 0; + int64_t logical_last = 0; + if (same_first && same_last) { + logical_first = first_dims[0] * static_cast(num_tensors); + logical_last = last_dims[0]; + } else if (same_first && !same_last) { + logical_first = first_dims[0]; + logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); + } else if (!same_first && same_last) { + logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); + logical_last = last_dims[0]; + } else { + logical_first = 1; + logical_last = 0; + for (size_t i = 0; i < num_tensors; ++i) { + logical_last += first_dims[i] * last_dims[i]; + } + } + size_t logical_data[2] = {static_cast(logical_first), + static_cast(logical_last)}; + grouped.logical_shape = nvte_make_shape(logical_data, 2); + grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); + + const int64_t last_idx = static_cast(num_tensors - 1); + const int64_t total_elems = need_offsets + ? (offsets[last_idx] + numel(last_idx)) + : (logical_first * logical_last); + const size_t total_bytes = static_cast(total_elems) * elem_size; + + grouped.data = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, + tensors[i]->rowwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + + NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; + NVTEGroupedTensor h = grouped.handle.get(); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); + + const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); + if (include_columnwise) { + grouped.columnwise_data = cuda_alloc(total_bytes); + for (size_t i = 0; i < num_tensors; ++i) { + const size_t offset_bytes = static_cast(offsets[i]) * elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, + tensors[i]->columnwise_dptr(), + grouped.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + } + NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), + static_cast(dtype), + grouped.logical_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); + } + + if (!same_first) { + grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); + } + + if (!same_last) { + grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); + } + + if (!same_first || !same_last) { + grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), + num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTEShape off_shape = 
nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor); + } + + if (isFp8Type(dtype)) { + std::vector scale_inv_cpu(num_tensors, 1.f); + for (size_t i = 0; i < num_tensors; ++i) { + tensors[i]->to_cpu(); + scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; + } + grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); + NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), + sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); + NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); + NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; + nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + } + + return grouped; +} + } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b528a79b4..082677c97 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -504,6 +504,60 @@ int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; +// Custom deleters for RAII +struct CudaDeleter { + void operator()(void* p) const { if (p) cudaFree(p); } +}; +struct GroupedTensorDeleter { + void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } +}; + +template +using CudaPtr = std::unique_ptr; +using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; + +// Helper to allocate CUDA memory into a CudaPtr +template +CudaPtr cuda_alloc(size_t bytes) { + void* ptr = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); + return CudaPtr(static_cast(ptr)); +} + +// Helper owning GPU buffers that back NVTEGroupedTensor. +// NVTEGroupedTensor does not own memory; data/offsets/scales +// must be allocated and freed by the test. 
+struct GroupedBuffers { + GroupedTensorHandle handle; + CudaPtr<> data; + CudaPtr<> scale_inv; + CudaPtr first_dims_dev; + CudaPtr last_dims_dev; + CudaPtr offsets_dev; + CudaPtr<> columnwise_data; + NVTEShape logical_shape{}; + std::vector offsets_host; + std::vector tensor_bytes; + size_t num_tensors{0}; + size_t elem_size{0}; + DType dtype{DType::kFloat32}; + NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; + + GroupedBuffers() = default; + GroupedBuffers(const GroupedBuffers&) = delete; + GroupedBuffers& operator=(const GroupedBuffers&) = delete; + GroupedBuffers(GroupedBuffers&&) = default; + GroupedBuffers& operator=(GroupedBuffers&&) = default; + ~GroupedBuffers() = default; + + // Convenience accessors for raw pointers + NVTEGroupedTensor get_handle() const { return handle.get(); } + void* get_data() const { return data.get(); } +}; + +GroupedBuffers build_grouped_tensor(const std::vector& tensors, + const NVTEScalingMode scaling_mode); + } // namespace test #if FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a83cbe3e3..efe958f84 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -144,6 +144,7 @@ list(APPEND transformer_engine_cuda_sources fused_attn/fused_attn_fp8.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cublaslt_grouped_gemm.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp index 2532e96bb..286fc0cc9 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -126,3 +126,106 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) { delete reinterpret_cast(config); } } + +NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() { + return new transformer_engine::GroupedMatmulConfig; +} + +void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, + "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEGroupedMatmulConfigAvgM: { + int64_t val = config_.avg_m.value_or(0); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEGroupedMatmulConfigAvgN: { + int64_t val = config_.avg_n.value_or(0); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEGroupedMatmulConfigAvgK: { + int64_t val = config_.avg_k.value_or(0); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEGroupedMatmulConfigSMCount: + std::memcpy(buf, &config_.sm_count, attr_size); + 
break; + default: + NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, + "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEGroupedMatmulConfigAvgM: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_m = val; + break; + } + case kNVTEGroupedMatmulConfigAvgN: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_n = val; + break; + } + case kNVTEGroupedMatmulConfigAvgK: { + int64_t val; + std::memcpy(&val, buf, attr_size); + config_.avg_k = val; + break; + } + case kNVTEGroupedMatmulConfigSMCount: + std::memcpy(&config_.sm_count, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index 86a617b5f..ad38e8833 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -9,6 +9,9 @@ #include +#include +#include + namespace transformer_engine { struct MatmulConfig { @@ -31,6 +34,22 @@ struct MatmulConfig { }; }; +struct GroupedMatmulConfig { + // Average dimension hints for cuBLASLt algorithm selection heuristics. + // nullopt means "not set" - compute automatically from tensor shapes. 
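+  // Example (sketch): a caller could override a hint through the C API defined above, e.g.
+  //   int64_t avg_m_hint = 128;
+  //   nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM,
+  //                                            &avg_m_hint, sizeof(avg_m_hint));
+  // The value 128 and the handle name `cfg` are illustrative only.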
+  std::optional<int64_t> avg_m;
+  std::optional<int64_t> avg_n;
+  std::optional<int64_t> avg_k;
+
+  // Number of streaming multiprocessors to use in GEMM kernel
+  int sm_count = 0;
+
+  // Note: the API transfers the value type, not the std::optional wrapper
+  static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
+                                          sizeof(decltype(avg_n)::value_type),
+                                          sizeof(decltype(avg_k)::value_type), sizeof(sm_count)};
+};
+
 }  // namespace transformer_engine
 
 #endif  // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 02faad40d..e4e97abd9 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -302,13 +302,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   return ret;
 }
 
-/* cuBLAS version number at run-time */
-size_t cublas_version() {
-  // Cache version to avoid cuBLAS logging overhead
-  static size_t version = cublasLtGetVersion();
-  return version;
-}
-
 }  // namespace
 
 namespace transformer_engine {
@@ -501,8 +494,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #endif  // CUBLAS_VERSION >= 120800
   } else if (mxfp8_gemm) {
 #if CUBLAS_VERSION >= 120800
-    NVTE_CHECK(cublas_version() >= 120800,
-               "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
+    NVTE_CHECK(cuda::cublas_version() >= 120800,
+               "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
+               cuda::cublas_version());
 
     // Check that scales are in expected format
     NVTE_CHECK(inputA->with_gemm_swizzled_scales,
@@ -524,7 +518,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 
     // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
     // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
-    if (cublas_version() <= 120803) {
+    if (cuda::cublas_version() <= 120803) {
       const int64_t dummy_a_vec_stride = 1;
       NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
           operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
@@ -536,8 +530,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #endif  // CUBLAS_VERSION >= 120800
   } else if (use_fp4) {  // NVFP4 GEMM
 #if CUBLAS_VERSION >= 120800
-    NVTE_CHECK(cublas_version() >= 120800,
-               "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
+    NVTE_CHECK(cuda::cublas_version() >= 120800,
+               "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
+               cuda::cublas_version());
 
     // Check that scales are in expected format
     NVTE_CHECK(inputA->with_gemm_swizzled_scales,
@@ -572,9 +567,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
              (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
               inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
 #if CUBLAS_VERSION >= 120900
-    NVTE_CHECK(cublas_version() >= 120900,
+    NVTE_CHECK(cuda::cublas_version() >= 120900,
                "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
-               cublas_version());
+               cuda::cublas_version());
 
     // Check that matrix formats are valid
     NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
@@ -607,7 +602,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   }
 
 #if CUBLAS_VERSION >= 120800
-  if (cublas_version() >= 120800) {
+  if (cuda::cublas_version() >= 120800) {
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                      &scaling_mode_a, sizeof(scaling_mode_a)));
@@ -624,7 +619,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
 #if CUBLAS_VERSION >= 120800
-  if (cublas_version() >= 120800) {
+  if (cuda::cublas_version() >= 120800) {
     // NOTE: In all current cases where FP8 output is supported, the input is
     // scaled identically to the output.
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -711,9 +706,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); - NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, + NVTE_CHECK(cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000, "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cublas_version()); + cuda::cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( @@ -939,9 +934,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", transformer_engine::cuda::cudart_version()); NVTE_CHECK( - cublas_version() >= 120205 && cublas_version() < 130000, + cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cublas_version()); + cuda::cublas_version()); const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu new file mode 100644 index 000000000..a1206474e --- /dev/null +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -0,0 +1,645 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/handle_manager.h" +#include "../util/logging.h" +#include "./config.h" + +namespace { + +inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + +} // namespace + +#if CUBLAS_VERSION >= 130100 + +namespace { + +// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) +struct TensorShapeInfo { + const int64_t *first_dims; // nullptr if uniform + const int64_t *last_dims; // nullptr if uniform + const int64_t *offsets; // nullptr if need to compute + int64_t uniform_first; // used if first_dims == nullptr + int64_t uniform_last; // used if last_dims == nullptr + + // Create from GroupedTensor + static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + // When per-tensor dims are not provided, we must be in the uniform-shape case. + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 
0 : static_cast(t->get_common_last_dim()); + + return {first_ptr, last_ptr, + t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } + + // Create for C tensor (uses D's dimensions, only has offsets) + static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D) { + const bool has_first = D->first_dims.has_data(); + const bool has_last = D->last_dims.has_data(); + NVTE_CHECK(has_first || D->all_same_first_dim(), + "GroupedTensor D is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || D->all_same_last_dim(), + "GroupedTensor D is missing last_dims for varying shapes"); + + const int64_t *first_ptr = + has_first ? static_cast(D->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); + + return {first_ptr, last_ptr, + C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) + : nullptr, + uniform_first, uniform_last}; + } +}; + +// Helper functions to compute average dimensions from logical_shape for heuristics +// These are hints for cuBLASLt algorithm selection, don't need to be exact +inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { + // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) + // In both cases, dividing by num_tensors gives the average + return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); +} + +inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { + if (t->all_same_last_dim()) { + // logical_shape[1] is the common N + return static_cast(t->logical_shape.data[1]); + } + // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. 
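+  // Worked example (illustrative): three matrices with last dims {96, 112, 128} give
+  // logical_shape[1] = 96 + 112 + 128 = 336, so the returned hint is 336 / 3 = 112.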
+ return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); +} + +// Workspace layout for grouped GEMM +struct GroupedGemmSetupWorkspace { + void **A_ptrs; + void **B_ptrs; + void **C_ptrs; + void **D_ptrs; + float **alpha_ptrs; + float **beta_ptrs; + // Storage dimensions for cuBLAS matrix layouts + int *a_rows; + int *a_cols; + int *b_rows; + int *b_cols; + int *d_rows; // M (first dim) - also used for C + int *d_cols; // N (last dim) - also used for C + + // Initialize from workspace buffer + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) + static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { + GroupedGemmSetupWorkspace ws; + size_t offset = 0; + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + + // Pointer arrays first (all 8-byte aligned) + ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); + offset += ptr_size; + + // Int arrays for storage dimensions (4-byte aligned) + ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.b_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.b_cols = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.d_rows = reinterpret_cast(setup_ws_ptr + offset); + offset += int_size; + ws.d_cols = reinterpret_cast(setup_ws_ptr + offset); + + return ws; + } + + // Calculate required size for setup workspace + static size_t required_setup_size(size_t num_tensors, size_t alignment) { + const size_t ptr_size = num_tensors * sizeof(void *); + const size_t int_size = num_tensors * sizeof(int); + // Layout: 6 ptr arrays, then 6 int arrays + size_t size = 6 * ptr_size + 6 * int_size; + size = ((size + alignment - 1) / alignment) * alignment; + return size; + } +}; + +// ----------------------------------------------------------------------------- +// Helper routines to keep nvte_grouped_gemm readable +// ----------------------------------------------------------------------------- +inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, + const transformer_engine::GroupedTensor *inputB, + const transformer_engine::GroupedTensor *inputC, + const transformer_engine::GroupedTensor *outputD, + const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor) { + const size_t num_tensors = inputA->num_tensors; + NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1"); + NVTE_CHECK(inputB->num_tensors == num_tensors, + "Grouped GEMM: A and B must have the same number of tensors"); + // C can be NULL (will use D as C when beta=0) + if (inputC != nullptr) { + NVTE_CHECK(inputC->num_tensors == num_tensors, + "Grouped GEMM: A and C must have the same number of tensors"); + } + NVTE_CHECK(outputD->num_tensors == num_tensors, + "Grouped GEMM: A and D must have the same number of tensors"); + + // Validate alpha/beta have per-matrix values + const size_t alpha_numel = alpha_tensor->data.numel(); + const size_t 
beta_numel = beta_tensor->data.numel(); + NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, + ") elements, got ", alpha_numel); + NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, + ") elements, got ", beta_numel); + + auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2 || + dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16; + }; + auto is_output_dtype = [](transformer_engine::DType dtype) { + return dtype == transformer_engine::DType::kBFloat16 || + dtype == transformer_engine::DType::kFloat16 || + dtype == transformer_engine::DType::kFloat32; + }; + NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), + "Grouped GEMM inputs must be FP8, BF16, or FP16."); + // Only check C dtype if C is provided + if (inputC != nullptr) { + NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); + } + NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); + NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), + "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); + NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), + "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); +} + +// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. +// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and +// fallback to column-wise data when row-wise is absent. +// Contains all information needed for GEMM setup - shape already accounts for storage layout. +struct GroupedOperandSelection { + TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed + char *dptr = nullptr; + void *scale_inv = nullptr; + transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; + bool trans = false; +}; + +// Helper to create TensorShapeInfo from a GroupedTensor, optionally swapping first/last dims. +// When swap_dims=true, first_dims and last_dims are swapped to account for columnwise storage. +// Note: tensor_offsets are the same for rowwise and columnwise data (same element count per tensor). +inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor *t, + bool swap_dims) { + const bool has_first = t->first_dims.has_data(); + const bool has_last = t->last_dims.has_data(); + NVTE_CHECK(has_first || t->all_same_first_dim(), + "GroupedTensor is missing first_dims for varying shapes"); + NVTE_CHECK(has_last || t->all_same_last_dim(), + "GroupedTensor is missing last_dims for varying shapes"); + + const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; + const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; + const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); + const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); + + const int64_t *offsets_ptr = + t->tensor_offsets.has_data() ? 
static_cast(t->tensor_offsets.dptr) : nullptr; + + if (swap_dims) { + // Swap first/last to account for columnwise (transposed) storage + return {last_ptr, first_ptr, offsets_ptr, uniform_last, uniform_first}; + } + return {first_ptr, last_ptr, offsets_ptr, uniform_first, uniform_last}; +} + +inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, + bool trans, bool is_A) { + using namespace transformer_engine; + const bool has_row = t->has_data(); + const bool has_col = t->has_columnwise_data(); + NVTE_CHECK(has_row || has_col, + "Grouped GEMM operand is missing both row-wise and column-wise data"); + + // Currently only unquantized data and tensor-scaled FP8 are supported. + const auto sm = t->scaling_mode; + NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING, + "Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data"); + + const DType row_dtype = t->data.dtype; + const DType col_dtype = t->columnwise_data.dtype; + GroupedOperandSelection sel; + sel.trans = trans; + + const DType rep_dtype = has_row ? row_dtype : col_dtype; + const bool is_fp8 = is_fp8_dtype(rep_dtype); + const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); + + // Helper to select columnwise storage (swaps dims in shape) + auto use_columnwise = [&]() { + sel.dptr = static_cast(t->columnwise_data.dptr); + sel.scale_inv = t->columnwise_scale_inv.dptr; + sel.dtype = col_dtype; + sel.shape = create_shape_info(t, /*swap_dims=*/true); + }; + + // Helper to select row-wise storage + auto use_rowwise = [&]() { + sel.dptr = static_cast(t->data.dptr); + sel.scale_inv = t->scale_inv.dptr; + sel.dtype = row_dtype; + sel.shape = create_shape_info(t, /*swap_dims=*/false); + }; + + // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. + if (is_fp8 && !non_tn_fp8_ok) { + if (is_A) { + if (!sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); + use_columnwise(); + sel.trans = true; // using pre-transposed storage + return sel; + } + } else { // B + if (sel.trans) { + NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); + use_columnwise(); + sel.trans = false; // using pre-transposed storage + return sel; + } + } + } + + // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). + if (!has_row && has_col) { + // On Hopper FP8, this would break TN requirement - should have been handled above + NVTE_CHECK( + !is_fp8 || non_tn_fp8_ok, + "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); + use_columnwise(); + sel.trans = !trans; // flip transpose for pre-transposed storage + return sel; + } + + // Default: use row-wise data + use_rowwise(); + return sel; +} + +inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, + const char *workspace_name) { + NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); + const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); + NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, + ". 
Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); + return ws->data.dptr; +} + +inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t &descC, + cublasLtMatrixLayoutOpaque_t &descD, + const GroupedGemmSetupWorkspace &ws, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, + const transformer_engine::GroupedTensor *D, size_t num_tensors) { + const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); + const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); + const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); + + // Storage dimensions computed by kernel, leading dimension = rows + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, ws.a_rows, + ws.a_cols, ws.a_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, ws.b_rows, + ws.b_cols, ws.b_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.d_rows, + ws.d_cols, ws.d_rows)); + NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.d_rows, + ws.d_cols, ws.d_rows)); +} + +inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, + cublasOperation_t op_B) { + NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, + sizeof(op_A))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, + sizeof(op_B))); + + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + + int64_t alphabeta_batch_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, + CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, + &alphabeta_batch_stride, sizeof(int64_t))); +} + +inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, + const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel) { + const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); + const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); + if (!is_fp8_a && !is_fp8_b) return; + + if (is_fp8_a) { + void *a_scale_inv = A_sel.scale_inv; + NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); + } + if (is_fp8_b) { + void *b_scale_inv = B_sel.scale_inv; + NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); + } +} + +// Constants for grouped GEMM workspace (declared early for use in heuristics) +static constexpr size_t kGroupedGemmAlignment = 256; +static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB + +inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, + cublasLtMatmulDescOpaque_t &matmulDesc, + cublasLtMatrixLayoutOpaque_t &descA, + cublasLtMatrixLayoutOpaque_t &descB, + cublasLtMatrixLayoutOpaque_t 
&descC, + cublasLtMatrixLayoutOpaque_t &descD, + int64_t avg_m, int64_t avg_n, int64_t avg_k) { + cublasLtMatmulPreferenceOpaque_t preference; + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); + NVTE_CHECK_CUBLAS( + cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); + NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); + + cublasLtMatmulHeuristicResult_t heuristicResult; + int returnedResults = 0; + auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, + &preference, 1, &heuristicResult, &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS grouped GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); + NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); + return heuristicResult.algo; +} + +// Single kernel that sets up all GEMM parameters. +// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions, +// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. +// We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel. +__global__ void setup_grouped_gemm_kernel( + // Output arrays + void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols, + int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs, + // Inputs + char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta, + TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size, + size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr, + size_t num_tensors) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_tensors) return; + + // Get dimensions for this tensor (from array or uniform value) + int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; + int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; + int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; + int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; + int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first; + int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last; + + // Compute offsets (from array or compute from uniform dims) + int64_t a_offset = + A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); + int64_t b_offset = + B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); + int64_t c_offset = + C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); + int64_t d_offset = + D_meta.offsets ? 
D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); + + // Compute data pointers + A_ptrs[idx] = a_base + a_offset * a_elem_size; + B_ptrs[idx] = b_base + b_offset * b_elem_size; + C_ptrs[idx] = c_base + c_offset * c_elem_size; + D_ptrs[idx] = d_base + d_offset * d_elem_size; + + // Compute storage dimensions for cuBLAS matrix layouts. + // For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS, + // so rows=last, cols=first. For columnwise, dims are already swapped. + a_rows[idx] = static_cast(a_last); + a_cols[idx] = static_cast(a_first); + b_rows[idx] = static_cast(b_last); + b_cols[idx] = static_cast(b_first); + // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). + d_rows[idx] = static_cast(d_first); + d_cols[idx] = static_cast(d_last); + + // Fill alpha/beta pointers (per-matrix) + alpha_ptrs[idx] = alpha_ptr + idx; + beta_ptrs[idx] = beta_ptr + idx; +} + +// Launch the setup kernel to populate workspace arrays +inline void launch_grouped_gemm_setup( + const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, + const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, + const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, + const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { + // Use shape info from selection (already accounts for columnwise dimension swap) + TensorShapeInfo A_meta = A_sel.shape; + TensorShapeInfo B_meta = B_sel.shape; + TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D); + TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); + + char *c_base = static_cast(C->data.dptr); + char *d_base = static_cast(D->data.dptr); + + const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); + const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); + const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); + const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); + + const int threads_per_block = 256; + const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; + + setup_grouped_gemm_kernel<<>>( + ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, + ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base, + A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size, + static_cast(alpha_tensor->data.dptr), static_cast(beta_tensor->data.dptr), + num_tensors); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { + return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); +} + +} // namespace + +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_grouped_gemm); + using namespace transformer_engine; + + // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.1+ + const int current_device = cuda::current_device(); + NVTE_CHECK(cuda::sm_arch(current_device) >= 100, + "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture."); + NVTE_CHECK(cuda::cublas_version() >= 130100, + "nvte_grouped_gemm 
requires cuBLAS 13.1+, but run-time cuBLAS version is ", + cuda::cublas_version()); + + // Convert to internal types + const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); + const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); + const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL + GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); + const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); + const Tensor *beta_tensor = convertNVTETensorCheck(beta); + Tensor *wspace_setup = convertNVTETensor(workspace_setup); + Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); + + // Parse config (if provided) + GroupedMatmulConfig config_; + if (config != nullptr) { + config_ = *reinterpret_cast(config); + } + + // Validate inputs and num_tensors + validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); + + // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) + const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; + const size_t num_tensors = inputA->num_tensors; + + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to + // mirror the non-grouped GEMM logic for FP8 layout constraints. + const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); + const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); + + // Workspaces: setup (pointer arrays) and cuBLAS + const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); + const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; + + void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, + "Grouped GEMM setup workspace"); + void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, + "Grouped GEMM cuBLAS workspace"); + + auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( + static_cast(setup_workspace_ptr), num_tensors); + launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, + beta_tensor, num_tensors, stream); + + // Get cuBLAS handle + using cublasHandleManager = detail::HandleManager; + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); + + // Setup cuBLAS operations + cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; + + // Create grouped matrix layouts + cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; + init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, + num_tensors); + + // Create matmul descriptor + cublasLtMatmulDescOpaque_t matmulDesc; + init_matmul_desc(matmulDesc, op_A, op_B); + set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); + + // Compute average dimensions for heuristics + // K dimension: if transa, K is A's first dim; if not, K is A's last dim + // Use original inputA and transa for heuristics (not modified A_sel.trans) + int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); + int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD)); + int64_t avg_k_val = + config_.avg_k.value_or(transa ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); + + // Heuristic selection + cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, + descD, avg_m_val, avg_n_val, avg_k_val); + + // Execute the grouped GEMM + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, + setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, + setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, + setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, + kGroupedGemmCublasWorkspaceSize, stream)); +} + +#else // CUBLAS_VERSION < 130100 + +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream) { + NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ", + CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); +} + +#endif // CUBLAS_VERSION >= 130100 diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index b304ed34b..1afc9828e 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ +#include + #include "transformer_engine.h" #ifdef __cplusplus @@ -20,6 +22,9 @@ extern "C" { /*! \brief Configuration for matrix multiplication. */ typedef void *NVTEMatmulConfig; +/*! \brief Configuration for grouped matrix multiplication. */ +typedef void *NVTEGroupedMatmulConfig; + /*! \enum NVTEMatmulConfigAttribute * \brief Type of option for matrix multiplication. */ @@ -52,6 +57,36 @@ enum NVTEMatmulConfigAttribute { kNVTEMatmulConfigNumAttributes }; +/*! \enum NVTEGroupedMatmulConfigAttribute + * \brief Type of option for grouped matrix multiplication. + */ +enum NVTEGroupedMatmulConfigAttribute { + /*! Average M dimension hint + * + * Optional hint for average M dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from D's logical shape. + */ + kNVTEGroupedMatmulConfigAvgM = 0, + /*! Average N dimension hint + * + * Optional hint for average N dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from D's logical shape. + */ + kNVTEGroupedMatmulConfigAvgN = 1, + /*! Average K (reduction) dimension hint + * + * Optional hint for average K dimension across all matrices in the group. + * Used by cuBLASLt for algorithm selection heuristics. If not set, + * computed automatically from A's logical shape. + */ + kNVTEGroupedMatmulConfigAvgK = 2, + /*! Number of streaming multiprocessors to use in GEMM kernel. */ + kNVTEGroupedMatmulConfigSMCount = 3, + kNVTEGroupedMatmulConfigNumAttributes +}; + /*! \brief Create a matrix multiplication configuration. */ NVTEMatmulConfig nvte_create_matmul_config(); @@ -82,6 +117,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA /*! \brief Destroy a matrix multiplication configuration. */ void nvte_destroy_matmul_config(NVTEMatmulConfig config); +/*! \brief Create a grouped matrix multiplication configuration. 
*/ +NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config(); + +/*! \brief Query an option in grouped matrix multiplication configuration. + * + * \param[in] config Grouped matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in grouped matrix multiplication configuration. + * + * \param[in] config Grouped matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, + NVTEGroupedMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes); + +/*! \brief Destroy a grouped matrix multiplication configuration. */ +void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config); + /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). * * This has been deprecated in favor of nvte_cublas_gemm_v2. @@ -228,6 +295,46 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C + * + * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. + * Will error at runtime if compiled with an older cuBLAS version or run on + * a pre-Blackwell GPU. + * + * Performs batched GEMM on a collection of matrices with potentially different shapes. + * All tensors in the group must have compatible dimensions for matrix multiplication. + * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous + * memory layout and shape metadata. + * + * \param[in] A Input grouped tensor A. + * \param[in] transa Whether to transpose A matrices. + * \param[in] B Input grouped tensor B. + * \param[in] transb Whether to transpose B matrices. + * \param[in] C Input grouped tensor C (can be NULL for beta=0). + * \param[out] D Output grouped tensor D. + * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). + * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). + * \param[in] workspace_setup Workspace tensor for pointer array setup. + * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. + * \param[in] config Additional configuration (can be NULL for defaults). + * \param[in] stream CUDA stream for the operation. 
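+ *
+ * Minimal call sketch (workspace and scalar tensor names are illustrative; passing NULL for C
+ * requires beta to be all zeros, and a NULL config selects the defaults):
+ *   nvte_grouped_gemm(A, 0, B, 0, nullptr, D, alpha, beta,
+ *                     ws_setup, ws_cublas, nullptr, stream);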
+ * + * Requirements: + * - cuBLAS 13.1+ (CUDA 13.1+) + * - Blackwell (SM100) or newer GPU architecture + * - A, B, C (if provided), D must have the same num_tensors + * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] + * - Shape compatibility: if transa=false, transb=false: + * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) + */ +void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, + const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, + const NVTETensor beta, NVTETensor workspace_setup, + NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus @@ -331,6 +438,70 @@ class MatmulConfigWrapper { NVTEMatmulConfig config_ = nullptr; }; +/*! \struct GroupedMatmulConfigWrapper + * \brief C++ wrapper for NVTEGroupedMatmulConfig. + */ +class GroupedMatmulConfigWrapper { + public: + GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {} + + GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete; + GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete; + + GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_grouped_matmul_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~GroupedMatmulConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_grouped_matmul_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEGroupedMatmulConfig. + * + * \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper. + */ + operator NVTEGroupedMatmulConfig() const noexcept { return config_; } + + /*! \brief Set average M dimension hint for algorithm selection. */ + void set_avg_m(int64_t avg_m) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m, + sizeof(int64_t)); + } + + /*! \brief Set average N dimension hint for algorithm selection. */ + void set_avg_n(int64_t avg_n) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n, + sizeof(int64_t)); + } + + /*! \brief Set average K dimension hint for algorithm selection. */ + void set_avg_k(int64_t avg_k) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k, + sizeof(int64_t)); + } + + /*! \brief Set number of streaming multiprocessors to use. */ + void set_sm_count(int sm_count) { + nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count, + sizeof(int)); + } + + private: + /*! \brief Wrapped NVTEGroupedMatmulConfig. 
*/ + NVTEGroupedMatmulConfig config_ = nullptr; +}; + } // namespace transformer_engine #endif // __cplusplus diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index f99900bac..4b43940a5 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -6,6 +6,8 @@ #include "../util/cuda_runtime.h" +#include + #include #include @@ -210,6 +212,12 @@ int cudart_version() { return version; } +size_t cublas_version() { + // Cache version to avoid cuBLAS logging overhead + static size_t version = cublasLtGetVersion(); + return version; +} + } // namespace cuda } // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index c696f6b57..f0aa23962 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -73,6 +73,12 @@ const std::string &include_directory(bool required = false); */ int cudart_version(); +/* \brief cuBLAS version number at run-time + * + * Versions may differ between compile-time and run-time. + */ +size_t cublas_version(); + } // namespace cuda } // namespace transformer_engine From 5671fd3675906cda1ade26c24a65d3dedd88eb89 Mon Sep 17 00:00:00 2001 From: Kshitij Janardan Lakhani Date: Tue, 27 Jan 2026 22:00:28 -0800 Subject: [PATCH 14/22] Revert "[common] Add support for cuBLASLt GEMM for GroupedTensor (#2502)" This reverts commit 9bb9d22645cf5d137a763fe439bae9f4e2b57457. --- tests/cpp/operator/CMakeLists.txt | 1 - tests/cpp/operator/test_grouped_gemm.cu | 308 --------- tests/cpp/test_common.cu | 163 ----- tests/cpp/test_common.h | 54 -- transformer_engine/common/CMakeLists.txt | 1 - transformer_engine/common/gemm/config.cpp | 103 --- transformer_engine/common/gemm/config.h | 19 - .../common/gemm/cublaslt_gemm.cu | 35 +- .../common/gemm/cublaslt_grouped_gemm.cu | 645 ------------------ .../common/include/transformer_engine/gemm.h | 171 ----- .../common/util/cuda_runtime.cpp | 8 - transformer_engine/common/util/cuda_runtime.h | 6 - 12 files changed, 20 insertions(+), 1494 deletions(-) delete mode 100644 tests/cpp/operator/test_grouped_gemm.cu delete mode 100644 transformer_engine/common/gemm/cublaslt_grouped_gemm.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 08a683949..26efb3796 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -30,7 +30,6 @@ add_executable(test_operator test_causal_softmax.cu test_swizzle.cu test_swap_first_dims.cu - test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu deleted file mode 100644 index 35c4375cb..000000000 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ /dev/null @@ -1,308 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "../test_common.h" - -using namespace transformer_engine; -using namespace test; - -namespace { - -enum class InputCase { - kFP8Current, - kBF16, -}; - -enum class ShapeCase { - kAllSame, - kSameFirst, - kSameLast, - kAllDifferent, -}; - -size_t grouped_setup_workspace_size(const size_t num_tensors) { - const size_t ptr_bytes = num_tensors * sizeof(void*); - const size_t int_bytes = num_tensors * sizeof(int); - // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols) - size_t size = 6 * ptr_bytes + 6 * int_bytes; - const size_t alignment = 256; - size = ((size + alignment - 1) / alignment) * alignment; - return size; -} - -Tensor make_fp8_operand(const std::string& name, const std::vector& shape) { - Tensor input_fp32(name + "_fp32", shape, DType::kFloat32); - fillUniform(&input_fp32); - - Tensor fp8(name, shape, TypeInfo::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING); - - nvte_compute_amax(input_fp32.data(), fp8.data(), 0); - QuantizationConfigWrapper config; - nvte_compute_scale_from_amax(fp8.data(), config, 0); - nvte_quantize(input_fp32.data(), fp8.data(), 0); - return fp8; -} - -Tensor make_bf16_operand(const std::string& name, const std::vector& shape) { - Tensor t(name, shape, DType::kBFloat16); - const size_t numel = shape[0] * shape[1]; - std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f)); - NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(), - numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice)); - return t; -} - -struct TestParams { - InputCase input_case; - bool transa; - bool transb; - ShapeCase shape_case; - bool use_null_c = false; // When true, pass nullptr for C (valid when beta=0) -}; - -// Returns a vector of (M, N, K) tuples for each GEMM in the group. -// M - number of rows in output D -// N - number of columns in output D -// K - reduction dimension shared between A and B -std::vector> make_shapes(ShapeCase scase) { - switch (scase) { - case ShapeCase::kAllSame: - return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; - case ShapeCase::kSameFirst: - // Same M (first dim), varying N and K - return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; - case ShapeCase::kSameLast: - // Same N (last dim), varying M and K - return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; - case ShapeCase::kAllDifferent: - default: - return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; - } -} - -void run_grouped_gemm_case(const TestParams& params) { -#if CUBLAS_VERSION < 130100 - GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is " - << CUBLAS_VERSION << "."; -#else - if (getDeviceComputeCapability() < blackwellComputeCapability) { - GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; - } - - const std::vector> shapes = make_shapes(params.shape_case); - - const size_t num_gemms = shapes.size(); - std::vector A_tensors; - std::vector B_tensors; - std::vector D_multi; - - A_tensors.reserve(num_gemms); - B_tensors.reserve(num_gemms); - D_multi.reserve(num_gemms); - - for (size_t i = 0; i < num_gemms; ++i) { - const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{M, K} - : std::vector{K, M}; - const std::vector b_shape = params.transb ? 
std::vector{K, N} - : std::vector{N, K}; - switch (params.input_case) { - case InputCase::kFP8Current: { - A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); - break; - } - case InputCase::kBF16: { - A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); - B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); - break; - } - } - D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), - std::vector{M, N}, - DType::kBFloat16)); - } - - std::vector A_ptrs(num_gemms); - std::vector B_ptrs(num_gemms); - std::vector D_ptrs(num_gemms); - std::vector workspaces(num_gemms); - std::vector workspace_ptrs(num_gemms, nullptr); - std::vector A_views; - std::vector B_views; - A_views.reserve(num_gemms); - B_views.reserve(num_gemms); - - // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) - std::vector bias_ptrs(num_gemms, nullptr); - std::vector gelu_ptrs(num_gemms, nullptr); - - const size_t cublas_ws_bytes = 32ull * 1024 * 1024; - - for (size_t i = 0; i < num_gemms; ++i) { - A_ptrs[i] = A_tensors[i].data(); - B_ptrs[i] = B_tensors[i].data(); - D_ptrs[i] = D_multi[i].data(); - workspaces[i] = Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); - workspace_ptrs[i] = workspaces[i].data(); - A_views.push_back(&A_tensors[i]); - B_views.push_back(&B_tensors[i]); - } - - nvte_multi_tensor_gemm(A_ptrs.data(), - B_ptrs.data(), - D_ptrs.data(), - bias_ptrs.data(), - gelu_ptrs.data(), - static_cast(num_gemms), - params.transa, - params.transb, - false, // grad - workspace_ptrs.data(), - false, // accumulate - false, // use_split_accumulator - 0, // sm_count - 0); - - GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); - GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); - - std::vector C_tensors; - std::vector D_group_tensors; - C_tensors.reserve(num_gemms); - D_group_tensors.reserve(num_gemms); - for (size_t i = 0; i < num_gemms; ++i) { - const auto [M, N, K] = shapes[i]; - (void)K; - if (!params.use_null_c) { - C_tensors.emplace_back(Tensor("C" + std::to_string(i), - std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16)); - } - D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), - std::vector{static_cast(M), static_cast(N)}, - DType::kBFloat16)); - NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, bytes(D_group_tensors.back().rowwise_shape(), D_group_tensors.back().dtype()))); - } - - std::vector C_views, D_views; - for (size_t i = 0; i < num_gemms; ++i) { - if (!params.use_null_c) { - C_views.push_back(&C_tensors[i]); - } - D_views.push_back(&D_group_tensors[i]); - } - - std::optional grouped_C; - if (!params.use_null_c) { - grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); - } - GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); - - // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) - Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); - Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); - std::vector alpha_vals(num_gemms, 1.f); - std::vector beta_vals(num_gemms, 0.f); - NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), - num_gemms * sizeof(float), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), - num_gemms * sizeof(float), 
cudaMemcpyHostToDevice)); - - const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); - Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); - Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); - - nvte_grouped_gemm(grouped_A.get_handle(), - params.transa, - grouped_B.get_handle(), - params.transb, - params.use_null_c ? nullptr : grouped_C->get_handle(), - grouped_D.get_handle(), - alpha_tensor.data(), - beta_tensor.data(), - setup_ws.data(), - cublas_ws.data(), - nullptr, // config (use defaults) - 0); - - for (size_t i = 0; i < num_gemms; ++i) { - Tensor grouped_split("grouped_D" + std::to_string(i), - std::vector{static_cast(std::get<0>(shapes[i])), - static_cast(std::get<1>(shapes[i]))}, - D_multi[i].dtype()); - const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), - static_cast(grouped_D.get_data()) + offset_bytes, - grouped_D.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - grouped_split.to_cpu(); - D_multi[i].to_cpu(); - auto [atol, rtol] = getTolerances(D_multi[i].dtype()); - compareResults("grouped_vs_multi", - grouped_split, - D_multi[i].rowwise_cpu_dptr(), - true, - atol, - rtol); - } -#endif // CUBLAS_VERSION >= 130100 -} - -class GroupedGemmTest : public ::testing::TestWithParam {}; - -TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { - run_grouped_gemm_case(GetParam()); -} - -std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { - constexpr const char* kInputNames[] = {"FP8Current", "BF16"}; - constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; - const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") + - "tb" + (info.param.transb ? "T" : "N"); - const std::string null_c = info.param.use_null_c ? 
"_NullC" : ""; - return std::string(kInputNames[static_cast(info.param.input_case)]) + "_" + - kShapeNames[static_cast(info.param.shape_case)] + "_" + layout + null_c; -} - -// TestParams: {input_case, transa, transb, shape_case, use_null_c} -const std::vector kTestParams = { - // Basic tests - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, - {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, - {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, - {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, - {InputCase::kBF16, false, true, ShapeCase::kSameLast, false}, - {InputCase::kBF16, false, false, ShapeCase::kAllSame, false}, - {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false}, - // Test NULL C (valid when beta=0) - {InputCase::kBF16, false, false, ShapeCase::kAllSame, true}, -}; - -INSTANTIATE_TEST_SUITE_P(OperatorTest, - GroupedGemmTest, - ::testing::ValuesIn(kTestParams), - MakeGroupedGemmTestName); - -} // namespace diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index af99d9c42..ed961bfe9 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -9,7 +9,6 @@ #include #include -#include #include #include #include @@ -1058,166 +1057,4 @@ std::array get_scale_tensor_dims(const size_t rows, return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; } -GroupedBuffers build_grouped_tensor(const std::vector& tensors, - const NVTEScalingMode scaling_mode) { - NVTE_CHECK(!tensors.empty(), "No tensors provided for grouped tensor build."); - const NVTEShape shape = tensors[0]->rowwise_shape(); - const DType dtype = tensors[0]->dtype(); - const size_t num_tensors = tensors.size(); - const size_t elem_size = typeToNumBits(dtype) / 8; - GroupedBuffers grouped; - grouped.elem_size = elem_size; - grouped.num_tensors = num_tensors; - grouped.dtype = dtype; - grouped.scaling_mode = scaling_mode; - grouped.tensor_bytes.resize(num_tensors); - grouped.offsets_host.resize(num_tensors, 0); - - std::vector first_dims(num_tensors); - std::vector last_dims(num_tensors); - for (size_t i = 0; i < num_tensors; ++i) { - const auto s = tensors[i]->rowwise_shape(); - NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors."); - first_dims[i] = static_cast(s.data[0]); - last_dims[i] = static_cast(s.data[1]); - grouped.tensor_bytes[i] = bytes(s, dtype); - } - - const bool same_first = std::all_of(first_dims.begin(), first_dims.end(), - [&](int64_t v) { return v == first_dims[0]; }); - const bool same_last = std::all_of(last_dims.begin(), last_dims.end(), - [&](int64_t v) { return v == last_dims[0]; }); - - std::vector offsets(num_tensors, 0); - auto random_padding = [&]() -> int64_t { - // Random padding ensuring 16-byte alignment regardless of element size - // cuBLAS requires aligned pointers for vectorized loads - static std::mt19937 gen(12345); - std::uniform_int_distribution dist(0, 3); - // Calculate elements needed for 16-byte alignment in bytes, rounded up - const size_t align_elements = - std::max(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size - return dist(gen) * static_cast(align_elements); - }; - - auto numel = [&](size_t idx) -> int64_t { - return first_dims[idx] * last_dims[idx]; - }; - - const bool need_offsets = !same_first || !same_last; - if (need_offsets) { - offsets[0] = 0; - for (size_t i = 1; i < num_tensors; ++i) { - offsets[i] = offsets[i - 1] + numel(i - 1) + random_padding(); - } - } else { - for (size_t i = 0; i < num_tensors; ++i) { - 
offsets[i] = static_cast(i) * numel(0); - } - } - grouped.offsets_host = offsets; - - int64_t logical_first = 0; - int64_t logical_last = 0; - if (same_first && same_last) { - logical_first = first_dims[0] * static_cast(num_tensors); - logical_last = last_dims[0]; - } else if (same_first && !same_last) { - logical_first = first_dims[0]; - logical_last = std::accumulate(last_dims.begin(), last_dims.end(), int64_t{0}); - } else if (!same_first && same_last) { - logical_first = std::accumulate(first_dims.begin(), first_dims.end(), int64_t{0}); - logical_last = last_dims[0]; - } else { - logical_first = 1; - logical_last = 0; - for (size_t i = 0; i < num_tensors; ++i) { - logical_last += first_dims[i] * last_dims[i]; - } - } - size_t logical_data[2] = {static_cast(logical_first), - static_cast(logical_last)}; - grouped.logical_shape = nvte_make_shape(logical_data, 2); - grouped.handle.reset(nvte_create_grouped_tensor(scaling_mode, num_tensors, grouped.logical_shape)); - - const int64_t last_idx = static_cast(num_tensors - 1); - const int64_t total_elems = need_offsets - ? (offsets[last_idx] + numel(last_idx)) - : (logical_first * logical_last); - const size_t total_bytes = static_cast(total_elems) * elem_size; - - grouped.data = cuda_alloc(total_bytes); - for (size_t i = 0; i < num_tensors; ++i) { - const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.data.get()) + offset_bytes, - tensors[i]->rowwise_dptr(), - grouped.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - } - - NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; - NVTEGroupedTensor h = grouped.handle.get(); - nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); - - const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); - if (include_columnwise) { - grouped.columnwise_data = cuda_alloc(total_bytes); - for (size_t i = 0; i < num_tensors; ++i) { - const size_t offset_bytes = static_cast(offsets[i]) * elem_size; - NVTE_CHECK_CUDA(cudaMemcpy(static_cast(grouped.columnwise_data.get()) + offset_bytes, - tensors[i]->columnwise_dptr(), - grouped.tensor_bytes[i], - cudaMemcpyDeviceToDevice)); - } - NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), - static_cast(dtype), - grouped.logical_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); - } - - if (!same_first) { - grouped.first_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.first_dims_dev.get(), first_dims.data(), - num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); - NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); - } - - if (!same_last) { - grouped.last_dims_dev = cuda_alloc(num_tensors * sizeof(int64_t)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.last_dims_dev.get(), last_dims.data(), - num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); - NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); - } - - if (!same_first || !same_last) { - grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t)); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(), - num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); - NVTEShape off_shape = 
nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor); - } - - if (isFp8Type(dtype)) { - std::vector scale_inv_cpu(num_tensors, 1.f); - for (size_t i = 0; i < num_tensors; ++i) { - tensors[i]->to_cpu(); - scale_inv_cpu[i] = tensors[i]->rowwise_cpu_scale_inv_ptr()[0]; - } - grouped.scale_inv = cuda_alloc(sizeof(float) * num_tensors); - NVTE_CHECK_CUDA(cudaMemcpy(grouped.scale_inv.get(), scale_inv_cpu.data(), - sizeof(float) * num_tensors, cudaMemcpyHostToDevice)); - NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); - NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); - nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); - } - - return grouped; -} - } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 082677c97..b528a79b4 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -504,60 +504,6 @@ int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; -// Custom deleters for RAII -struct CudaDeleter { - void operator()(void* p) const { if (p) cudaFree(p); } -}; -struct GroupedTensorDeleter { - void operator()(NVTEGroupedTensor h) const { if (h) nvte_destroy_grouped_tensor(h); } -}; - -template -using CudaPtr = std::unique_ptr; -using GroupedTensorHandle = std::unique_ptr, GroupedTensorDeleter>; - -// Helper to allocate CUDA memory into a CudaPtr -template -CudaPtr cuda_alloc(size_t bytes) { - void* ptr = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&ptr, bytes)); - return CudaPtr(static_cast(ptr)); -} - -// Helper owning GPU buffers that back NVTEGroupedTensor. -// NVTEGroupedTensor does not own memory; data/offsets/scales -// must be allocated and freed by the test. 
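// Illustrative usage sketch (an editorial example distilled from the grouped
// GEMM test above, not a line from the original file): the per-GEMM Tensors
// are packed into one contiguous buffer and the resulting handle is what
// nvte_grouped_gemm consumes, roughly:
//
//   std::vector<Tensor*> A_views = {...};   // views of the individual operands
//   GroupedBuffers grouped_A = build_grouped_tensor(A_views, NVTE_DELAYED_TENSOR_SCALING);
//   GroupedBuffers grouped_B = build_grouped_tensor(B_views, NVTE_DELAYED_TENSOR_SCALING);
//   GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);
//   nvte_grouped_gemm(grouped_A.get_handle(), transa,
//                     grouped_B.get_handle(), transb,
//                     /*C=*/nullptr, grouped_D.get_handle(),
//                     alpha.data(), beta.data(),
//                     setup_ws.data(), cublas_ws.data(),
//                     /*config=*/nullptr, /*stream=*/0);
//
// The GroupedBuffers instance must outlive the call, since the
// NVTEGroupedTensor handle only references the GPU buffers it owns.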
-struct GroupedBuffers { - GroupedTensorHandle handle; - CudaPtr<> data; - CudaPtr<> scale_inv; - CudaPtr first_dims_dev; - CudaPtr last_dims_dev; - CudaPtr offsets_dev; - CudaPtr<> columnwise_data; - NVTEShape logical_shape{}; - std::vector offsets_host; - std::vector tensor_bytes; - size_t num_tensors{0}; - size_t elem_size{0}; - DType dtype{DType::kFloat32}; - NVTEScalingMode scaling_mode{NVTE_DELAYED_TENSOR_SCALING}; - - GroupedBuffers() = default; - GroupedBuffers(const GroupedBuffers&) = delete; - GroupedBuffers& operator=(const GroupedBuffers&) = delete; - GroupedBuffers(GroupedBuffers&&) = default; - GroupedBuffers& operator=(GroupedBuffers&&) = default; - ~GroupedBuffers() = default; - - // Convenience accessors for raw pointers - NVTEGroupedTensor get_handle() const { return handle.get(); } - void* get_data() const { return data.get(); } -}; - -GroupedBuffers build_grouped_tensor(const std::vector& tensors, - const NVTEScalingMode scaling_mode); - } // namespace test #if FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index efe958f84..a83cbe3e3 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -144,7 +144,6 @@ list(APPEND transformer_engine_cuda_sources fused_attn/fused_attn_fp8.cu fused_attn/utils.cu gemm/cublaslt_gemm.cu - gemm/cublaslt_grouped_gemm.cu normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp index 286fc0cc9..2532e96bb 100644 --- a/transformer_engine/common/gemm/config.cpp +++ b/transformer_engine/common/gemm/config.cpp @@ -126,106 +126,3 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) { delete reinterpret_cast(config); } } - -NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() { - return new transformer_engine::GroupedMatmulConfig; -} - -void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, - NVTEGroupedMatmulConfigAttribute attr, void *buf, - size_t size_in_bytes, size_t *size_written) { - // Write attribute size - NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, - "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); - NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); - const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; - *size_written = attr_size; - - // Return immediately if buffer is not provided - if (buf == nullptr) { - return; - } - - // Check buffer size - NVTE_CHECK(size_in_bytes >= attr_size, - "Buffer is too small for grouped matmul config attribute " - "(attribute ", - static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, - " bytes)"); - - // Write to buffer - NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); - const auto &config_ = *reinterpret_cast(config); - switch (attr) { - case kNVTEGroupedMatmulConfigAvgM: { - int64_t val = config_.avg_m.value_or(0); - std::memcpy(buf, &val, attr_size); - break; - } - case kNVTEGroupedMatmulConfigAvgN: { - int64_t val = config_.avg_n.value_or(0); - std::memcpy(buf, &val, attr_size); - break; - } - case kNVTEGroupedMatmulConfigAvgK: { - int64_t val = config_.avg_k.value_or(0); - std::memcpy(buf, &val, attr_size); - break; - } - case kNVTEGroupedMatmulConfigSMCount: - std::memcpy(buf, &config_.sm_count, attr_size); - 
break; - default: - NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); - } -} - -void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, - NVTEGroupedMatmulConfigAttribute attr, - const void *buf, size_t size_in_bytes) { - // Check attribute and buffer - NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes, - "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); - const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr]; - NVTE_CHECK(size_in_bytes >= attr_size, - "Buffer is too small for grouped matmul config attribute " - "(attribute ", - static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, - " bytes)"); - NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); - - // Read from buffer - NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)"); - auto &config_ = *reinterpret_cast(config); - switch (attr) { - case kNVTEGroupedMatmulConfigAvgM: { - int64_t val; - std::memcpy(&val, buf, attr_size); - config_.avg_m = val; - break; - } - case kNVTEGroupedMatmulConfigAvgN: { - int64_t val; - std::memcpy(&val, buf, attr_size); - config_.avg_n = val; - break; - } - case kNVTEGroupedMatmulConfigAvgK: { - int64_t val; - std::memcpy(&val, buf, attr_size); - config_.avg_k = val; - break; - } - case kNVTEGroupedMatmulConfigSMCount: - std::memcpy(&config_.sm_count, buf, attr_size); - break; - default: - NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast(attr), ")"); - } -} - -void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) { - if (config != nullptr) { - delete reinterpret_cast(config); - } -} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h index ad38e8833..86a617b5f 100644 --- a/transformer_engine/common/gemm/config.h +++ b/transformer_engine/common/gemm/config.h @@ -9,9 +9,6 @@ #include -#include -#include - namespace transformer_engine { struct MatmulConfig { @@ -34,22 +31,6 @@ struct MatmulConfig { }; }; -struct GroupedMatmulConfig { - // Average dimension hints for cuBLASLt algorithm selection heuristics. - // nullopt means "not set" - compute automatically from tensor shapes. 
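  // Illustrative sketch of how a caller sets these hints through the public C
  // API declared in gemm.h (an example added for clarity, not original code):
  //
  //   NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();
  //   int64_t avg_m = 128;  // hypothetical average M across the group
  //   nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM,
  //                                            &avg_m, sizeof(avg_m));
  //   // ... pass cfg to nvte_grouped_gemm, then release it ...
  //   nvte_destroy_grouped_matmul_config(cfg);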
- std::optional avg_m; - std::optional avg_n; - std::optional avg_k; - - // Number of streaming multiprocessors to use in GEMM kernel - int sm_count = 0; - - // Note: API transfers the value type, not std::optional - static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type), - sizeof(decltype(avg_n)::value_type), - sizeof(decltype(avg_k)::value_type), sizeof(sm_count)}; -}; - } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index e4e97abd9..02faad40d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -302,6 +302,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla return ret; } +/* cuBLAS version number at run-time */ +size_t cublas_version() { + // Cache version to avoid cuBLAS logging overhead + static size_t version = cublasLtGetVersion(); + return version; +} + } // namespace namespace transformer_engine { @@ -494,9 +501,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif // CUBLAS_VERSION >= 120800 } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 - NVTE_CHECK(cuda::cublas_version() >= 120800, - "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", - cuda::cublas_version()); + NVTE_CHECK(cublas_version() >= 120800, + "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); // Check that scales are in expected format NVTE_CHECK(inputA->with_gemm_swizzled_scales, @@ -518,7 +524,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. 
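      // (Presumably the extra attribute value becomes part of the descriptor
      //  state that feeds the heuristic cache key, which is what keeps MXFP8
      //  GEMMs from colliding with non-block-scaled GEMMs in the cache.)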
- if (cuda::cublas_version() <= 120803) { + if (cublas_version() <= 120803) { const int64_t dummy_a_vec_stride = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, @@ -530,9 +536,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif // CUBLAS_VERSION >= 120800 } else if (use_fp4) { // NVFP4 GEMM #if CUBLAS_VERSION >= 120800 - NVTE_CHECK(cuda::cublas_version() >= 120800, - "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", - cuda::cublas_version()); + NVTE_CHECK(cublas_version() >= 120800, + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); // Check that scales are in expected format NVTE_CHECK(inputA->with_gemm_swizzled_scales, @@ -567,9 +572,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) { #if CUBLAS_VERSION >= 120900 - NVTE_CHECK(cuda::cublas_version() >= 120900, + NVTE_CHECK(cublas_version() >= 120900, "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ", - cuda::cublas_version()); + cublas_version()); // Check that matrix formats are valid NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D && @@ -602,7 +607,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } #if CUBLAS_VERSION >= 120800 - if (cuda::cublas_version() >= 120800) { + if (cublas_version() >= 120800) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a))); @@ -619,7 +624,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); #if CUBLAS_VERSION >= 120800 - if (cuda::cublas_version() >= 120800) { + if (cublas_version() >= 120800) { // NOTE: In all current cases where FP8 output is supported, the input is // scaled identically to the output. 
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -706,9 +711,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); - NVTE_CHECK(cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000, + NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cuda::cublas_version()); + cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( @@ -934,9 +939,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", transformer_engine::cuda::cudart_version()); NVTE_CHECK( - cuda::cublas_version() >= 120205 && cuda::cublas_version() < 130000, + cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", - cuda::cublas_version()); + cublas_version()); const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu deleted file mode 100644 index a1206474e..000000000 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ /dev/null @@ -1,645 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include - -#include - -#include "../common.h" -#include "../util/cuda_runtime.h" -#include "../util/handle_manager.h" -#include "../util/logging.h" -#include "./config.h" - -namespace { - -inline void CreateCublasHandle(cublasLtHandle_t *handle) { - NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); -} - -} // namespace - -#if CUBLAS_VERSION >= 130100 - -namespace { - -// Helper struct to pass per-tensor shape/offset info (pointer or uniform value) -struct TensorShapeInfo { - const int64_t *first_dims; // nullptr if uniform - const int64_t *last_dims; // nullptr if uniform - const int64_t *offsets; // nullptr if need to compute - int64_t uniform_first; // used if first_dims == nullptr - int64_t uniform_last; // used if last_dims == nullptr - - // Create from GroupedTensor - static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) { - const bool has_first = t->first_dims.has_data(); - const bool has_last = t->last_dims.has_data(); - // When per-tensor dims are not provided, we must be in the uniform-shape case. - NVTE_CHECK(has_first || t->all_same_first_dim(), - "GroupedTensor is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || t->all_same_last_dim(), - "GroupedTensor is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(t->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; - - const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); - const int64_t uniform_last = has_last ? 
0 : static_cast(t->get_common_last_dim()); - - return {first_ptr, last_ptr, - t->tensor_offsets.has_data() ? static_cast(t->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } - - // Create for C tensor (uses D's dimensions, only has offsets) - static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D) { - const bool has_first = D->first_dims.has_data(); - const bool has_last = D->last_dims.has_data(); - NVTE_CHECK(has_first || D->all_same_first_dim(), - "GroupedTensor D is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || D->all_same_last_dim(), - "GroupedTensor D is missing last_dims for varying shapes"); - - const int64_t *first_ptr = - has_first ? static_cast(D->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(D->last_dims.dptr) : nullptr; - const int64_t uniform_first = has_first ? 0 : static_cast(D->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(D->get_common_last_dim()); - - return {first_ptr, last_ptr, - C->tensor_offsets.has_data() ? static_cast(C->tensor_offsets.dptr) - : nullptr, - uniform_first, uniform_last}; - } -}; - -// Helper functions to compute average dimensions from logical_shape for heuristics -// These are hints for cuBLASLt algorithm selection, don't need to be exact -inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) { - // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first) - // In both cases, dividing by num_tensors gives the average - return static_cast(t->logical_shape.data[0]) / static_cast(t->num_tensors); -} - -inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) { - if (t->all_same_last_dim()) { - // logical_shape[1] is the common N - return static_cast(t->logical_shape.data[1]); - } - // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to avg via division. 
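  // Worked example (illustrative): with three matrices whose last dims are
  // {96, 112, 128}, logical_shape[1] = 96 + 112 + 128 = 336 and num_tensors = 3,
  // so the hint becomes 336 / 3 = 112. Integer truncation is acceptable here
  // because the value only steers cuBLASLt algorithm selection.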
- return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); -} - -// Workspace layout for grouped GEMM -struct GroupedGemmSetupWorkspace { - void **A_ptrs; - void **B_ptrs; - void **C_ptrs; - void **D_ptrs; - float **alpha_ptrs; - float **beta_ptrs; - // Storage dimensions for cuBLAS matrix layouts - int *a_rows; - int *a_cols; - int *b_rows; - int *b_cols; - int *d_rows; // M (first dim) - also used for C - int *d_cols; // N (last dim) - also used for C - - // Initialize from workspace buffer - // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) - static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { - GroupedGemmSetupWorkspace ws; - size_t offset = 0; - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - - // Pointer arrays first (all 8-byte aligned) - ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); - offset += ptr_size; - - // Int arrays for storage dimensions (4-byte aligned) - ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.b_rows = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.b_cols = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.d_rows = reinterpret_cast(setup_ws_ptr + offset); - offset += int_size; - ws.d_cols = reinterpret_cast(setup_ws_ptr + offset); - - return ws; - } - - // Calculate required size for setup workspace - static size_t required_setup_size(size_t num_tensors, size_t alignment) { - const size_t ptr_size = num_tensors * sizeof(void *); - const size_t int_size = num_tensors * sizeof(int); - // Layout: 6 ptr arrays, then 6 int arrays - size_t size = 6 * ptr_size + 6 * int_size; - size = ((size + alignment - 1) / alignment) * alignment; - return size; - } -}; - -// ----------------------------------------------------------------------------- -// Helper routines to keep nvte_grouped_gemm readable -// ----------------------------------------------------------------------------- -inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA, - const transformer_engine::GroupedTensor *inputB, - const transformer_engine::GroupedTensor *inputC, - const transformer_engine::GroupedTensor *outputD, - const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor) { - const size_t num_tensors = inputA->num_tensors; - NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1"); - NVTE_CHECK(inputB->num_tensors == num_tensors, - "Grouped GEMM: A and B must have the same number of tensors"); - // C can be NULL (will use D as C when beta=0) - if (inputC != nullptr) { - NVTE_CHECK(inputC->num_tensors == num_tensors, - "Grouped GEMM: A and C must have the same number of tensors"); - } - NVTE_CHECK(outputD->num_tensors == num_tensors, - "Grouped GEMM: A and D must have the same number of tensors"); - - // Validate alpha/beta have per-matrix values - const size_t alpha_numel = alpha_tensor->data.numel(); - const size_t 
beta_numel = beta_tensor->data.numel(); - NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors, - ") elements, got ", alpha_numel); - NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors, - ") elements, got ", beta_numel); - - auto is_fp8_or_16bit = [](transformer_engine::DType dtype) { - return dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2 || - dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16; - }; - auto is_output_dtype = [](transformer_engine::DType dtype) { - return dtype == transformer_engine::DType::kBFloat16 || - dtype == transformer_engine::DType::kFloat16 || - dtype == transformer_engine::DType::kFloat32; - }; - NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()), - "Grouped GEMM inputs must be FP8, BF16, or FP16."); - // Only check C dtype if C is provided - if (inputC != nullptr) { - NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32."); - } - NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32."); - NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(), - "Grouped GEMM: A tensor is missing both row-wise and column-wise data"); - NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(), - "Grouped GEMM: B tensor is missing both row-wise and column-wise data"); -} - -// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM. -// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and -// fallback to column-wise data when row-wise is absent. -// Contains all information needed for GEMM setup - shape already accounts for storage layout. -struct GroupedOperandSelection { - TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed - char *dptr = nullptr; - void *scale_inv = nullptr; - transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; - bool trans = false; -}; - -// Helper to create TensorShapeInfo from a GroupedTensor, optionally swapping first/last dims. -// When swap_dims=true, first_dims and last_dims are swapped to account for columnwise storage. -// Note: tensor_offsets are the same for rowwise and columnwise data (same element count per tensor). -inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor *t, - bool swap_dims) { - const bool has_first = t->first_dims.has_data(); - const bool has_last = t->last_dims.has_data(); - NVTE_CHECK(has_first || t->all_same_first_dim(), - "GroupedTensor is missing first_dims for varying shapes"); - NVTE_CHECK(has_last || t->all_same_last_dim(), - "GroupedTensor is missing last_dims for varying shapes"); - - const int64_t *first_ptr = has_first ? static_cast(t->first_dims.dptr) : nullptr; - const int64_t *last_ptr = has_last ? static_cast(t->last_dims.dptr) : nullptr; - const int64_t uniform_first = has_first ? 0 : static_cast(t->get_common_first_dim()); - const int64_t uniform_last = has_last ? 0 : static_cast(t->get_common_last_dim()); - - const int64_t *offsets_ptr = - t->tensor_offsets.has_data() ? 
static_cast(t->tensor_offsets.dptr) : nullptr; - - if (swap_dims) { - // Swap first/last to account for columnwise (transposed) storage - return {last_ptr, first_ptr, offsets_ptr, uniform_last, uniform_first}; - } - return {first_ptr, last_ptr, offsets_ptr, uniform_first, uniform_last}; -} - -inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t, - bool trans, bool is_A) { - using namespace transformer_engine; - const bool has_row = t->has_data(); - const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, - "Grouped GEMM operand is missing both row-wise and column-wise data"); - - // Currently only unquantized data and tensor-scaled FP8 are supported. - const auto sm = t->scaling_mode; - NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING, - "Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data"); - - const DType row_dtype = t->data.dtype; - const DType col_dtype = t->columnwise_data.dtype; - GroupedOperandSelection sel; - sel.trans = trans; - - const DType rep_dtype = has_row ? row_dtype : col_dtype; - const bool is_fp8 = is_fp8_dtype(rep_dtype); - const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); - - // Helper to select columnwise storage (swaps dims in shape) - auto use_columnwise = [&]() { - sel.dptr = static_cast(t->columnwise_data.dptr); - sel.scale_inv = t->columnwise_scale_inv.dptr; - sel.dtype = col_dtype; - sel.shape = create_shape_info(t, /*swap_dims=*/true); - }; - - // Helper to select row-wise storage - auto use_rowwise = [&]() { - sel.dptr = static_cast(t->data.dptr); - sel.scale_inv = t->scale_inv.dptr; - sel.dtype = row_dtype; - sel.shape = create_shape_info(t, /*swap_dims=*/false); - }; - - // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed. - if (is_fp8 && !non_tn_fp8_ok) { - if (is_A) { - if (!sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout"); - use_columnwise(); - sel.trans = true; // using pre-transposed storage - return sel; - } - } else { // B - if (sel.trans) { - NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout"); - use_columnwise(); - sel.trans = false; // using pre-transposed storage - return sel; - } - } - } - - // If only column-wise data is available, mirror the transpose flag (pre-transposed storage). - if (!has_row && has_col) { - // On Hopper FP8, this would break TN requirement - should have been handled above - NVTE_CHECK( - !is_fp8 || non_tn_fp8_ok, - "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration"); - use_columnwise(); - sel.trans = !trans; // flip transpose for pre-transposed storage - return sel; - } - - // Default: use row-wise data - use_rowwise(); - return sel; -} - -inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size, - const char *workspace_name) { - NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null."); - const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype); - NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name, - ". 
Required: ", required_size, " bytes, Available: ", provided_size, " bytes."); - return ws->data.dptr; -} - -inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t &descC, - cublasLtMatrixLayoutOpaque_t &descD, - const GroupedGemmSetupWorkspace &ws, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, - const transformer_engine::GroupedTensor *D, size_t num_tensors) { - const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype); - const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype); - const cudaDataType_t D_type = get_cuda_dtype(D->dtype()); - - // Storage dimensions computed by kernel, leading dimension = rows - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, ws.a_rows, - ws.a_cols, ws.a_rows)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, ws.b_rows, - ws.b_cols, ws.b_rows)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.d_rows, - ws.d_cols, ws.d_rows)); - NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.d_rows, - ws.d_cols, ws.d_rows)); -} - -inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A, - cublasOperation_t op_B) { - NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); - - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A, - sizeof(op_A))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B, - sizeof(op_B))); - - cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, sizeof(pointer_mode))); - - int64_t alphabeta_batch_stride = 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, - CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, - &alphabeta_batch_stride, sizeof(int64_t))); -} - -inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, - const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel) { - const bool is_fp8_a = is_fp8_dtype(A_sel.dtype); - const bool is_fp8_b = is_fp8_dtype(B_sel.dtype); - if (!is_fp8_a && !is_fp8_b) return; - - if (is_fp8_a) { - void *a_scale_inv = A_sel.scale_inv; - NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv))); - } - if (is_fp8_b) { - void *b_scale_inv = B_sel.scale_inv; - NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required"); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( - &matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv))); - } -} - -// Constants for grouped GEMM workspace (declared early for use in heuristics) -static constexpr size_t kGroupedGemmAlignment = 256; -static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB - -inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle, - cublasLtMatmulDescOpaque_t &matmulDesc, - cublasLtMatrixLayoutOpaque_t &descA, - cublasLtMatrixLayoutOpaque_t &descB, - cublasLtMatrixLayoutOpaque_t 
&descC, - cublasLtMatrixLayoutOpaque_t &descD, - int64_t avg_m, int64_t avg_n, int64_t avg_k) { - cublasLtMatmulPreferenceOpaque_t preference; - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference)); - NVTE_CHECK_CUBLAS( - cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &kGroupedGemmCublasWorkspaceSize, sizeof(size_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t))); - NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t))); - - cublasLtMatmulHeuristicResult_t heuristicResult; - int returnedResults = 0; - auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD, - &preference, 1, &heuristicResult, &returnedResults); - NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, - "Unable to find suitable cuBLAS grouped GEMM algorithm"); - NVTE_CHECK_CUBLAS(status); - NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM"); - return heuristicResult.algo; -} - -// Single kernel that sets up all GEMM parameters. -// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions, -// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes. -// We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel. -__global__ void setup_grouped_gemm_kernel( - // Output arrays - void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols, - int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs, - // Inputs - char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta, - TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size, - size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr, - size_t num_tensors) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_tensors) return; - - // Get dimensions for this tensor (from array or uniform value) - int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first; - int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last; - int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first; - int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last; - int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first; - int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last; - - // Compute offsets (from array or compute from uniform dims) - int64_t a_offset = - A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last); - int64_t b_offset = - B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last); - int64_t c_offset = - C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last); - int64_t d_offset = - D_meta.offsets ? 
D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last); - - // Compute data pointers - A_ptrs[idx] = a_base + a_offset * a_elem_size; - B_ptrs[idx] = b_base + b_offset * b_elem_size; - C_ptrs[idx] = c_base + c_offset * c_elem_size; - D_ptrs[idx] = d_base + d_offset * d_elem_size; - - // Compute storage dimensions for cuBLAS matrix layouts. - // For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS, - // so rows=last, cols=first. For columnwise, dims are already swapped. - a_rows[idx] = static_cast(a_last); - a_cols[idx] = static_cast(a_first); - b_rows[idx] = static_cast(b_last); - b_cols[idx] = static_cast(b_first); - // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). - d_rows[idx] = static_cast(d_first); - d_cols[idx] = static_cast(d_last); - - // Fill alpha/beta pointers (per-matrix) - alpha_ptrs[idx] = alpha_ptr + idx; - beta_ptrs[idx] = beta_ptr + idx; -} - -// Launch the setup kernel to populate workspace arrays -inline void launch_grouped_gemm_setup( - const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel, - const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C, - const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor, - const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) { - // Use shape info from selection (already accounts for columnwise dimension swap) - TensorShapeInfo A_meta = A_sel.shape; - TensorShapeInfo B_meta = B_sel.shape; - TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D); - TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D); - - char *c_base = static_cast(C->data.dptr); - char *d_base = static_cast(D->data.dptr); - - const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype); - const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype); - const size_t c_elem_size = transformer_engine::typeToSize(C->dtype()); - const size_t d_elem_size = transformer_engine::typeToSize(D->dtype()); - - const int threads_per_block = 256; - const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block; - - setup_grouped_gemm_kernel<<>>( - ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, - ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base, - A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size, - static_cast(alpha_tensor->data.dptr), static_cast(beta_tensor->data.dptr), - num_tensors); - - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) { - return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment); -} - -} // namespace - -void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, - const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, - const NVTETensor beta, NVTETensor workspace_setup, - NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, - cudaStream_t stream) { - NVTE_API_CALL(nvte_grouped_gemm); - using namespace transformer_engine; - - // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.1+ - const int current_device = cuda::current_device(); - NVTE_CHECK(cuda::sm_arch(current_device) >= 100, - "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture."); - NVTE_CHECK(cuda::cublas_version() >= 130100, - "nvte_grouped_gemm 
requires cuBLAS 13.1+, but run-time cuBLAS version is ", - cuda::cublas_version()); - - // Convert to internal types - const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A); - const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B); - const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL - GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D); - const Tensor *alpha_tensor = convertNVTETensorCheck(alpha); - const Tensor *beta_tensor = convertNVTETensorCheck(beta); - Tensor *wspace_setup = convertNVTETensor(workspace_setup); - Tensor *wspace_cublas = convertNVTETensor(workspace_cublas); - - // Parse config (if provided) - GroupedMatmulConfig config_; - if (config != nullptr) { - config_ = *reinterpret_cast(config); - } - - // Validate inputs and num_tensors - validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor); - - // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data) - const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD; - const size_t num_tensors = inputA->num_tensors; - - // Select operand storage (row-wise vs column-wise) and adjust transpose flags to - // mirror the non-grouped GEMM logic for FP8 layout constraints. - const auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); - const auto B_sel = select_grouped_operand(inputB, static_cast(transb), /*is_A=*/false); - - // Workspaces: setup (pointer arrays) and cuBLAS - const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors); - const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize; - - void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size, - "Grouped GEMM setup workspace"); - void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size, - "Grouped GEMM cuBLAS workspace"); - - auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers( - static_cast(setup_workspace_ptr), num_tensors); - launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor, - beta_tensor, num_tensors, stream); - - // Get cuBLAS handle - using cublasHandleManager = detail::HandleManager; - cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); - - // Setup cuBLAS operations - cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N; - - // Create grouped matrix layouts - cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD; - init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD, - num_tensors); - - // Create matmul descriptor - cublasLtMatmulDescOpaque_t matmulDesc; - init_matmul_desc(matmulDesc, op_A, op_B); - set_fp8_scale_pointers(matmulDesc, A_sel, B_sel); - - // Compute average dimensions for heuristics - // K dimension: if transa, K is A's first dim; if not, K is A's last dim - // Use original inputA and transa for heuristics (not modified A_sel.trans) - int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD)); - int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD)); - int64_t avg_k_val = - config_.avg_k.value_or(transa ? 
compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA)); - - // Heuristic selection - cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, - descD, avg_m_val, avg_n_val, avg_k_val); - - // Execute the grouped GEMM - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs, - setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB, - setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC, - setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr, - kGroupedGemmCublasWorkspaceSize, stream)); -} - -#else // CUBLAS_VERSION < 130100 - -void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, - const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, - const NVTETensor beta, NVTETensor workspace_setup, - NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, - cudaStream_t stream) { - NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.1+, but compile-time cuBLAS version is ", - CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer."); -} - -#endif // CUBLAS_VERSION >= 130100 diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 1afc9828e..b304ed34b 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -11,8 +11,6 @@ #ifndef TRANSFORMER_ENGINE_GEMM_H_ #define TRANSFORMER_ENGINE_GEMM_H_ -#include - #include "transformer_engine.h" #ifdef __cplusplus @@ -22,9 +20,6 @@ extern "C" { /*! \brief Configuration for matrix multiplication. */ typedef void *NVTEMatmulConfig; -/*! \brief Configuration for grouped matrix multiplication. */ -typedef void *NVTEGroupedMatmulConfig; - /*! \enum NVTEMatmulConfigAttribute * \brief Type of option for matrix multiplication. */ @@ -57,36 +52,6 @@ enum NVTEMatmulConfigAttribute { kNVTEMatmulConfigNumAttributes }; -/*! \enum NVTEGroupedMatmulConfigAttribute - * \brief Type of option for grouped matrix multiplication. - */ -enum NVTEGroupedMatmulConfigAttribute { - /*! Average M dimension hint - * - * Optional hint for average M dimension across all matrices in the group. - * Used by cuBLASLt for algorithm selection heuristics. If not set, - * computed automatically from D's logical shape. - */ - kNVTEGroupedMatmulConfigAvgM = 0, - /*! Average N dimension hint - * - * Optional hint for average N dimension across all matrices in the group. - * Used by cuBLASLt for algorithm selection heuristics. If not set, - * computed automatically from D's logical shape. - */ - kNVTEGroupedMatmulConfigAvgN = 1, - /*! Average K (reduction) dimension hint - * - * Optional hint for average K dimension across all matrices in the group. - * Used by cuBLASLt for algorithm selection heuristics. If not set, - * computed automatically from A's logical shape. - */ - kNVTEGroupedMatmulConfigAvgK = 2, - /*! Number of streaming multiprocessors to use in GEMM kernel. */ - kNVTEGroupedMatmulConfigSMCount = 3, - kNVTEGroupedMatmulConfigNumAttributes -}; - /*! \brief Create a matrix multiplication configuration. */ NVTEMatmulConfig nvte_create_matmul_config(); @@ -117,38 +82,6 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA /*! \brief Destroy a matrix multiplication configuration. */ void nvte_destroy_matmul_config(NVTEMatmulConfig config); -/*! \brief Create a grouped matrix multiplication configuration. 
*/ -NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config(); - -/*! \brief Query an option in grouped matrix multiplication configuration. - * - * \param[in] config Grouped matrix multiplication configuration. - * \param[in] attr Option type. - * \param[out] buf Memory address to write option value. Ignored if - * NULL. - * \param[in] size_in_bytes Size of buf. - * \param[out] size_written Number of bytes that have been written to - * buf. If buf is NULL, then the number of - * bytes that would have been written. - */ -void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, - NVTEGroupedMatmulConfigAttribute attr, void *buf, - size_t size_in_bytes, size_t *size_written); - -/*! \brief Set an option in grouped matrix multiplication configuration. - * - * \param[in] config Grouped matrix multiplication configuration. - * \param[in] attr Option type. - * \param[out] buf Memory address to read option value. - * \param[in] size_in_bytes Size of buf. - */ -void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, - NVTEGroupedMatmulConfigAttribute attr, - const void *buf, size_t size_in_bytes); - -/*! \brief Destroy a grouped matrix multiplication configuration. */ -void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config); - /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). * * This has been deprecated in favor of nvte_cublas_gemm_v2. @@ -295,46 +228,6 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); - -/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C - * - * \note Requires cuBLAS 13.1+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. - * Will error at runtime if compiled with an older cuBLAS version or run on - * a pre-Blackwell GPU. - * - * Performs batched GEMM on a collection of matrices with potentially different shapes. - * All tensors in the group must have compatible dimensions for matrix multiplication. - * Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous - * memory layout and shape metadata. - * - * \param[in] A Input grouped tensor A. - * \param[in] transa Whether to transpose A matrices. - * \param[in] B Input grouped tensor B. - * \param[in] transb Whether to transpose B matrices. - * \param[in] C Input grouped tensor C (can be NULL for beta=0). - * \param[out] D Output grouped tensor D. - * \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements). - * \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements). - * \param[in] workspace_setup Workspace tensor for pointer array setup. - * \param[in] workspace_cublas Workspace tensor for cuBLAS operations. - * \param[in] config Additional configuration (can be NULL for defaults). - * \param[in] stream CUDA stream for the operation. 
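 *
 * Example (an illustrative sketch mirroring the C++ tests, not a normative
 * part of this header): for a group of three independent GEMMs packed into
 * grouped tensors A, B and D, with per-matrix alpha = 1 and beta = 0,
 *
 *   nvte_grouped_gemm(A, 0, B, 0, nullptr, D, alpha, beta,
 *                     setup_ws, cublas_ws, nullptr, stream);
 *
 * computes D[i] = A[i] @ B[i] for every i in the group.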
- * - * Requirements: - * - cuBLAS 13.1+ (CUDA 13.1+) - * - Blackwell (SM100) or newer GPU architecture - * - A, B, C (if provided), D must have the same num_tensors - * - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] - * - Shape compatibility: if transa=false, transb=false: - * - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i]) - */ -void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, - const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, - const NVTETensor beta, NVTETensor workspace_setup, - NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, - cudaStream_t stream); - #ifdef __cplusplus } // extern "C" #endif // __cplusplus @@ -438,70 +331,6 @@ class MatmulConfigWrapper { NVTEMatmulConfig config_ = nullptr; }; -/*! \struct GroupedMatmulConfigWrapper - * \brief C++ wrapper for NVTEGroupedMatmulConfig. - */ -class GroupedMatmulConfigWrapper { - public: - GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {} - - GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete; - GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete; - - GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} { - other.config_ = nullptr; - } - GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) { - if (config_ != nullptr) { - nvte_destroy_grouped_matmul_config(config_); - } - config_ = other.config_; - other.config_ = nullptr; - return *this; - } - - ~GroupedMatmulConfigWrapper() { - if (config_ != nullptr) { - nvte_destroy_grouped_matmul_config(config_); - config_ = nullptr; - } - } - - /*! \brief Get the underlying NVTEGroupedMatmulConfig. - * - * \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper. - */ - operator NVTEGroupedMatmulConfig() const noexcept { return config_; } - - /*! \brief Set average M dimension hint for algorithm selection. */ - void set_avg_m(int64_t avg_m) { - nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m, - sizeof(int64_t)); - } - - /*! \brief Set average N dimension hint for algorithm selection. */ - void set_avg_n(int64_t avg_n) { - nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n, - sizeof(int64_t)); - } - - /*! \brief Set average K dimension hint for algorithm selection. */ - void set_avg_k(int64_t avg_k) { - nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k, - sizeof(int64_t)); - } - - /*! \brief Set number of streaming multiprocessors to use. */ - void set_sm_count(int sm_count) { - nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count, - sizeof(int)); - } - - private: - /*! \brief Wrapped NVTEGroupedMatmulConfig. 
*/ - NVTEGroupedMatmulConfig config_ = nullptr; -}; - } // namespace transformer_engine #endif // __cplusplus diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 4b43940a5..f99900bac 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -6,8 +6,6 @@ #include "../util/cuda_runtime.h" -#include - #include #include @@ -212,12 +210,6 @@ int cudart_version() { return version; } -size_t cublas_version() { - // Cache version to avoid cuBLAS logging overhead - static size_t version = cublasLtGetVersion(); - return version; -} - } // namespace cuda } // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index f0aa23962..c696f6b57 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -73,12 +73,6 @@ const std::string &include_directory(bool required = false); */ int cudart_version(); -/* \brief cuBLAS version number at run-time - * - * Versions may differ between compile-time and run-time. - */ -size_t cublas_version(); - } // namespace cuda } // namespace transformer_engine From 80187b25bd67b0832c53aa2f70befc841d4b89f3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 14 Apr 2026 16:38:51 +0000 Subject: [PATCH 15/22] Add guards to new functions --- transformer_engine/common/cast/core/common.cuh | 2 ++ transformer_engine/common/util/ptx.cuh | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 6b0d82c54..f26abeb90 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -39,11 +39,13 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) { return cols % alignment_requirement == 0; } +#ifndef __HIP_PLATFORM_AMD__ __device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) { size_t addr = reinterpret_cast(p); addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1); return reinterpret_cast(addr); } +#endif //#ifndef __HIP_PLATFORM_AMD__ namespace kernel { diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 539f5a87a..7bfdb1bee 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -291,6 +291,7 @@ __device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#ifndef __HIP_PLATFORM_AMD__ __device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) { constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { @@ -337,6 +338,7 @@ __device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_da "Try recompiling with sm_XXXa instead of sm_XXX."); } } +#endif //#ifndef __HIP_PLATFORM_AMD__ constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_EXPONENT_BIAS = 127; @@ -781,6 +783,7 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c } } +#ifndef __HIP_PLATFORM_AMD__ template __device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest( const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) { @@ -952,6 +955,7 @@ __device__ __forceinline__ uint32_t 
mul_cvt_bf16_to_fp4_8x_stochastic_rounding( } return out_8x; } +#endif //#ifndef __HIP_PLATFORM_AMD__ #endif // FP4_TYPE_SUPPORTED @@ -1806,6 +1810,7 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) { } #endif //#ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ // Loads single BF16/FP16 element from shared memory state space __device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) { const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); @@ -1858,6 +1863,7 @@ __device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem, asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16)); } #endif +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace ptx namespace { From 6ec90f8617cc29af4f691b729f395e3ba4572ae3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 14 Apr 2026 19:40:16 +0000 Subject: [PATCH 16/22] Updated signatures --- .../common/fused_attn_rocm/fused_attn.cpp | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index e787b31c8..d85858d9a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -279,7 +279,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { using namespace transformer_engine; // TODO: Add return_max_logit support @@ -345,15 +345,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, bool return_max_logit, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - bool cuda_graph, NVTE_Bias_Type bias_type, + bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -431,8 +432,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor 
workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -523,7 +524,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; @@ -614,8 +616,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -720,7 +722,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; @@ -806,7 +809,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; From 7911721e09f140135154a53dee4327ff148d85c6 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 15 Apr 2026 14:59:13 +0000 Subject: [PATCH 17/22] Adjusted call sites for deterministic kwd --- .../common/fused_attn_rocm/fused_attn.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index d85858d9a..fdfddf28d 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -395,7 +395,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, 
bias_type, attn_mask_type, softmax_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph); + cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_qkvpacked( @@ -478,7 +478,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, + deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ @@ -572,7 +573,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_kvpacked( @@ -668,7 +669,8 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph); + d, window_size_left, window_size_right, false, cuda_graph, + deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -763,7 +765,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -856,7 +858,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph); + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { From 45d50dff421e60c93ba21640f9e14af8a7549748 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 16 Apr 2026 16:30:58 +0000 Subject: [PATCH 18/22] Build corrections and hardening for ptx --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 1 + transformer_engine/common/util/ptx.cuh | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index ef7331405..60b044bde 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -626,6 +626,7 @@ void performTest(float (*OP)(const 
float), rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence rng_state.from_cpu(); + QuantizationConfigWrapper quant_config; quant_config.set_use_fast_math(use_fast_math); #ifdef __HIP_PLATFORM_AMD__ quant_config.set_stochastic_rounding(use_stochastic_rounding); diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 7bfdb1bee..36cc8a952 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -137,6 +137,7 @@ constexpr bool is_supported_arch() { #endif //#ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -291,7 +292,6 @@ __device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -#ifndef __HIP_PLATFORM_AMD__ __device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) { constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { @@ -391,6 +391,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { #endif //#ifndef __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, @@ -499,6 +500,7 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { NVTE_DEVICE_ERROR("fence_proxy_async_shared_cta is only supported on SM 9.0+."); #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +#endif //#ifndef __HIP_PLATFORM_AMD__ template struct alignas(2 * sizeof(T)) FPx2 { @@ -959,6 +961,7 @@ __device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding( #endif // FP4_TYPE_SUPPORTED +#ifndef __HIP_PLATFORM_AMD__ // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { @@ -1126,7 +1129,6 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) } -#ifndef __HIP_PLATFORM_AMD__ __device__ __forceinline__ int32_t elect_one_sync(uint32_t mask = 0xFFFFFFFFu) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) int32_t pred = 0; @@ -1868,6 +1870,7 @@ __device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem, namespace { +#ifndef __HIP_PLATFORM_AMD__ template __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -2004,6 +2007,7 @@ __forceinline__ __device__ void copy_2d_to_sharedx3( NVTE_DEVICE_ERROR("copy_2d_to_sharedx3 is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace } // namespace transformer_engine From b5318e1bcdf2ba7263ee318c71de5251bb322cd2 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 Apr 2026 19:18:26 +0000 Subject: [PATCH 19/22] Added back rounding error mitigation in comparison --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 19 +++++++++++++++++++ 1 file changed, 19 
insertions(+) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 60b044bde..e5b6d716b 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -400,7 +400,21 @@ void compare_nvfp4_tensors(const std::string& name, const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); +#ifdef __HIP_PLATFORM_AMD__ + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + if (mismatch) { + // Check if it is just a failure of round to nearest choosing different + // side of the real value + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + mismatch = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); + } +#else const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); +#endif if (mismatch) { total_mismatches++; // Optional: limit number of detailed messages to avoid overwhelming output @@ -662,8 +676,13 @@ void performTest(float (*OP)(const float), } ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); +#ifdef __HIP_PLATFORM_AMD__ + const double atol = 0.05; + const double rtol = 0.1; +#else const double atol = 1.0E-6; const double rtol = 1.0E-6; +#endif // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); From 6a51c42679ff0b36e832071e20e477b49fe01f61 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 24 Apr 2026 14:59:47 +0000 Subject: [PATCH 20/22] PR feedback --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 5 ----- tests/jax/test_fused_attn.py | 2 +- transformer_engine/jax/cpp_extensions/attention.py | 4 ++-- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index e5b6d716b..66244444f 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -676,13 +676,8 @@ void performTest(float (*OP)(const float), } ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); -#ifdef __HIP_PLATFORM_AMD__ - const double atol = 0.05; - const double rtol = 0.1; -#else const double atol = 1.0E-6; const double rtol = 1.0E-6; -#endif // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index bd92bcdda..26f5514d3 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -421,7 +421,7 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - if get_device_compute_capability(0) >= 100 and self.is_training and not is_hip_extension(): + if not is_hip_extension() and get_device_compute_capability(0) >= 100 and self.is_training: if FusedAttnHelper.is_non_deterministic_allowed() and ( (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) or get_cudnn_version() < 90700 diff --git a/transformer_engine/jax/cpp_extensions/attention.py 
b/transformer_engine/jax/cpp_extensions/attention.py index 3b3d6bd5e..65857cbc4 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3609,8 +3609,8 @@ def fused_attn_bwd( softmax_offset, (None, HEAD_AXES, None, None) ) - compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities) and is_training and not is_hip_extension(): + compute_capabilities = get_all_device_compute_capability() if not is_hip_extension() else [] + if any(x >= 100 for x in compute_capabilities) and is_training: assert ( FusedAttnHelper.is_non_deterministic_allowed() and get_cudnn_version() >= (9, 7, 0) From 4d27c387ca7da54195fd4078fb446d228056447b Mon Sep 17 00:00:00 2001 From: ipanfilo <145064111+ipanfilo@users.noreply.github.com> Date: Tue, 5 May 2026 13:46:35 -0400 Subject: [PATCH 21/22] Fix build on Pytorch 2.11 (#16505) (#575) --- build_tools/hipify/custom_map.json | 1 - .../pytorch/csrc/extensions/attention.cpp | 11 ++++++++++- transformer_engine/pytorch/csrc/extensions/gemm.cpp | 13 ++++++++++--- .../pytorch/csrc/extensions/normalization.cpp | 11 +++++++++-- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 5fc3cded0..6525731f5 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -11,7 +11,6 @@ "__nv_fp8_e5m2" : "te_hip_fp8_e5m2", "__nv_fp8_e4m3" : "te_hip_fp8_e4m3", "cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA", - "at::cuda::CUDAGuard" : "at::hip::HIPGuardMasqueradingAsCUDA", "__nv_fp4_e2m1" : "__hip_fp4_e2m1", "__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1", "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index bf62db8c3..72087a521 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -8,6 +10,13 @@ #include "common.h" #include "pybind.h" +#include +#if USE_ROCM && TORCH_VERSION_MINOR < 11 +using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA; +#else +using TECUDAGuard = at::cuda::CUDAGuard; +#endif + namespace { constexpr int block_size = 512; @@ -112,7 +121,7 @@ std::vector fused_attn_fwd( // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(cu_seqlens_q.device()); + TECUDAGuard device_guard(cu_seqlens_q.device()); auto none = py::none(); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 941b88e36..6898ce387 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -20,6 +20,13 @@ #include "transformer_engine/transformer_engine.h" #include "util.h" +#include +#if USE_ROCM && TORCH_VERSION_MINOR < 11 +using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA; +#else +using TECUDAGuard = at::cuda::CUDAGuard; +#endif + namespace { void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) { @@ -100,7 +107,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Ensure that cublasLt handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace.device()); + TECUDAGuard device_guard(workspace.device()); // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); @@ -388,7 +395,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, // Ensure that cublasLt handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace.device()); + TECUDAGuard device_guard(workspace.device()); // TODO: Handle scaling modes NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; @@ -442,7 +449,7 @@ std::optional> te_general_grouped_gemm( // Ensure that cublasLt handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(workspace[0].device()); + TECUDAGuard device_guard(workspace[0].device()); void* output_data_ptr = nullptr; if (single_output) { diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index b78982d4d..8f8eed2c3 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -10,6 +10,13 @@ #include "common/util/system.h" #include "pybind.h" +#include +#if USE_ROCM && TORCH_VERSION_MINOR < 11 +using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA; +#else +using TECUDAGuard = at::cuda::CUDAGuard; +#endif + namespace transformer_engine::pytorch { std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, @@ -69,7 +76,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. - at::cuda::CUDAGuard device_guard(input.cast().device()); + TECUDAGuard device_guard(input.cast().device()); // Input and param tensors auto none = py::none(); @@ -319,7 +326,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Ensure that cuDNN handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(input.cast().device()); + TECUDAGuard device_guard(input.cast().device()); // Input and param tensors auto none = py::none(); From 036130e484fd1cd062c0d3ea8a972d55742246d6 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 11 May 2026 16:29:47 +0000 Subject: [PATCH 22/22] Updated w/ explicit guard for bottom_right swa alignment --- .../common/fused_attn_rocm/fused_attn.cpp | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index fdfddf28d..2564f059d 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -228,6 +228,29 @@ const std::unordered_map mNVTEMaskTypeStr = { {NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, "PADDING_CAUSAL_BOTTOM_RIGHT_MASK"}, }; +// True for the two *_BOTTOM_RIGHT_MASK variants, false otherwise. +inline bool implied_bottom_right_diagonal(NVTE_Mask_Type attn_mask_type) { + return attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK; +} + +// The ROCm/AITER fused-attn backend derives mask anchoring solely from the +// NVTE_Mask_Type enum and does not consume `bottom_right_diagonal`. Any +// divergence between the boolean and the alignment implied by the mask type +// would silently produce numerically incorrect attention, so we reject it +// here until AITER plumbs an explicit alignment parameter. +inline void check_bottom_right_diagonal(NVTE_Mask_Type attn_mask_type, + bool bottom_right_diagonal) { + if (bottom_right_diagonal != implied_bottom_right_diagonal(attn_mask_type)) { + NVTE_ERROR( + "ROCm fused attention does not support a `bottom_right_diagonal` value " + "that diverges from the alignment implied by `attn_mask_type`. 
Use " + "NVTE_CAUSAL_BOTTOM_RIGHT_MASK or NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK " + "for bottom-right alignment, or the corresponding non-bottom-right " + "mask types for top-left alignment."); + } +} + void log_fused_attn_config( const char* func_name, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t batch_size, @@ -358,6 +381,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); @@ -436,6 +460,7 @@ void nvte_fused_attn_bwd_qkvpacked( bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); @@ -530,6 +555,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); @@ -621,6 +647,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); @@ -729,6 +756,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); @@ -816,6 +844,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
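The check_bottom_right_diagonal() guard added above encodes a caller-side invariant: `bottom_right_diagonal` must agree with the alignment already implied by the mask type, since the ROCm backend derives the alignment from the NVTE_Mask_Type enum alone. Below is a minimal sketch of keeping the two consistent by deriving the flag from the mask; the enum names are taken from the hunks above, while the header path and the trailing nvte_fused_attn_fwd call are assumptions for illustration, not part of this patch series.

    // Minimal sketch, illustrative only -- not part of this patch series.
    // Assumes the public NVTE_Mask_Type enum is visible via the
    // transformer_engine/fused_attn.h header path.
    #include <transformer_engine/fused_attn.h>

    // Mirrors implied_bottom_right_diagonal() from the hunk above: only the
    // two *_BOTTOM_RIGHT_MASK variants use bottom-right alignment.
    static bool bottom_right_for_mask(NVTE_Mask_Type mask) {
      return mask == NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
             mask == NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK;
    }

    // Deriving the flag from the mask keeps the pair consistent, so the new
    // check_bottom_right_diagonal() in the ROCm path does not error out:
    //   NVTE_Mask_Type mask = NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK;
    //   bool bottom_right_diagonal = bottom_right_for_mask(mask);  // true here
    //   nvte_fused_attn_fwd(..., window_size_left, window_size_right,
    //                       bottom_right_diagonal, workspace, stream);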